How to Pivot a DataFrame in PySpark
Key Insights
- Pivoting transforms rows into columns using the `groupBy().pivot().agg()` pattern, requiring three components: a grouping column, a pivot column, and an aggregation function.
- Always specify explicit pivot values when possible—this prevents Spark from scanning the entire dataset to discover unique values, dramatically improving performance on large datasets.
- High-cardinality pivot columns can cause memory issues and produce unwieldy schemas; filter or bucket your pivot column before transforming, and configure `spark.sql.pivotMaxValues` as a safety net.
Introduction
Pivoting is one of those operations that seems simple until you need to do it at scale. The concept is straightforward: take values from rows and spread them across columns. You’ve probably done this a thousand times in Excel or pandas. But when you’re working with billions of rows in a distributed environment, the rules change.
In PySpark, pivoting is essential for reshaping data for reporting dashboards, preparing features for machine learning, transforming time-series data into wide format, and creating cross-tabulations for analysis. The operation is deceptively powerful—and deceptively easy to get wrong. This article covers the mechanics, the gotchas, and the performance considerations you need to know.
Understanding the Pivot Operation
Before diving into code, let’s establish what pivoting actually does. Consider sales data in a normalized format:
| product | month | revenue |
|---|---|---|
| Widget | Jan | 1000 |
| Widget | Feb | 1200 |
| Gadget | Jan | 800 |
| Gadget | Feb | 950 |
After pivoting on the month column, you get:
| product | Jan | Feb |
|---|---|---|
| Widget | 1000 | 1200 |
| Gadget | 800 | 950 |
Three components make this work:
- Grouping column(s): what defines each row in the output (`product`)
- Pivot column: the column whose values become new column headers (`month`)
- Aggregation: how to combine values when multiple rows map to the same cell (`sum`, `avg`, etc.)
Let’s create a sample DataFrame to work with throughout this article:
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum, avg, count, col

spark = SparkSession.builder.appName("PivotDemo").getOrCreate()

data = [
    ("Widget", "Jan", "North", 1000),
    ("Widget", "Jan", "South", 1100),
    ("Widget", "Feb", "North", 1200),
    ("Widget", "Feb", "South", 1300),
    ("Widget", "Mar", "North", 900),
    ("Widget", "Mar", "South", 1050),
    ("Gadget", "Jan", "North", 800),
    ("Gadget", "Jan", "South", 750),
    ("Gadget", "Feb", "North", 950),
    ("Gadget", "Feb", "South", 880),
    ("Gadget", "Mar", "North", 1100),
    ("Gadget", "Mar", "South", 1020),
]
df = spark.createDataFrame(data, ["product", "month", "region", "revenue"])
df.show()
```
Basic Pivot Syntax
The pivot operation in PySpark follows the `groupBy().pivot().agg()` pattern. This chained method approach reads naturally: group your data, specify what to pivot on, then define how to aggregate.
```python
pivoted_df = (
    df
    .groupBy("product")
    .pivot("month")
    .agg(sum("revenue"))
)
pivoted_df.show()
```
Output:
```
+-------+----+----+----+
|product| Feb| Jan| Mar|
+-------+----+----+----+
| Gadget|1830|1550|2120|
| Widget|2500|2100|1950|
+-------+----+----+----+
```
Notice that Spark automatically discovered the unique values in the `month` column and created columns for each. The values are sorted alphabetically by default, which is why `Feb` appears before `Jan`.
You can also group by multiple columns:
```python
pivoted_by_region = (
    df
    .groupBy("product", "region")
    .pivot("month")
    .agg(sum("revenue"))
)
pivoted_by_region.show()
```
This produces one row per product-region combination, with monthly revenue spread across columns.
Specifying Pivot Values Explicitly
Here’s where many engineers leave performance on the table. When you call .pivot("column") without specifying values, Spark must scan the entire dataset to discover unique values in that column. On a 10GB dataset, that’s an expensive operation you’re running before the actual pivot even begins.
The solution is simple: provide the values explicitly.
```python
months = ["Jan", "Feb", "Mar"]

pivoted_explicit = (
    df
    .groupBy("product")
    .pivot("month", months)
    .agg(sum("revenue"))
)
pivoted_explicit.show()
```
Output:
```
+-------+----+----+----+
|product| Jan| Feb| Mar|
+-------+----+----+----+
| Gadget|1550|1830|2120|
| Widget|2100|2500|1950|
+-------+----+----+----+
```
Two benefits here. First, Spark skips the discovery scan entirely. Second, you control the column order—notice `Jan` now comes before `Feb`.
This approach also handles missing values gracefully. If you include a month that doesn’t exist in the data, Spark creates the column with nulls:
```python
months_with_april = ["Jan", "Feb", "Mar", "Apr"]

pivoted_with_missing = (
    df
    .groupBy("product")
    .pivot("month", months_with_april)
    .agg(sum("revenue"))
)
pivoted_with_missing.show()
```
The `Apr` column will contain nulls for all products. This is useful when you need consistent schemas across different data batches.
Multiple Aggregations on Pivoted Data
Real-world analysis rarely needs just one metric. You often want sum and average, or count and max. PySpark handles this elegantly:
```python
from pyspark.sql.functions import sum as spark_sum, avg as spark_avg

multi_agg_pivot = (
    df
    .groupBy("product")
    .pivot("month", ["Jan", "Feb", "Mar"])
    .agg(
        spark_sum("revenue").alias("total"),
        spark_avg("revenue").alias("average")
    )
)
multi_agg_pivot.show()
```
Output:
```
+-------+---------+-----------+---------+-----------+---------+-----------+
|product|Jan_total|Jan_average|Feb_total|Feb_average|Mar_total|Mar_average|
+-------+---------+-----------+---------+-----------+---------+-----------+
| Gadget|     1550|      775.0|     1830|      915.0|     2120|     1060.0|
| Widget|     2100|     1050.0|     2500|     1250.0|     1950|      975.0|
+-------+---------+-----------+---------+-----------+---------+-----------+
```
Spark automatically generates column names by combining the pivot value with the aggregation alias. The naming convention is `{pivot_value}_{aggregation_alias}`. If you don't provide an alias, Spark uses the function name, resulting in less readable names like `Jan_sum(revenue)`.
Handling Performance Considerations
Pivoting can be a memory hog. Each unique value in your pivot column becomes a new column in the output. Pivot on a column with 10,000 unique values, and you’ve just created a 10,000-column DataFrame. This causes several problems: shuffle operations become expensive, the driver needs to track an enormous schema, and downstream operations struggle with wide data.
Spark provides a safety valve: `spark.sql.pivotMaxValues`. By default, it's set to 10,000. If your pivot column exceeds this threshold, Spark throws an error rather than silently creating an unusable DataFrame.
```python
# Set a lower limit for safety
spark.conf.set("spark.sql.pivotMaxValues", 100)

# This will fail if the pivot column has more than 100 unique values
try:
    high_cardinality_pivot = (
        df
        .groupBy("product")
        .pivot("some_high_cardinality_column")
        .agg(sum("revenue"))
    )
except Exception as e:
    print(f"Pivot failed: {e}")
```
The better approach is to filter or bucket high-cardinality columns before pivoting:
```python
# Filter to only the values you need
relevant_months = ["Jan", "Feb", "Mar"]
filtered_pivot = (
    df
    .filter(col("month").isin(relevant_months))
    .groupBy("product")
    .pivot("month", relevant_months)
    .agg(sum("revenue"))
)

# Or bucket continuous values
from pyspark.sql.functions import when

df_with_buckets = df.withColumn(
    "revenue_bucket",
    when(col("revenue") < 900, "low")
    .when(col("revenue") < 1100, "medium")
    .otherwise("high")
)
bucketed_pivot = (
    df_with_buckets
    .groupBy("product")
    .pivot("revenue_bucket", ["low", "medium", "high"])
    .agg(count("*"))
)
```
Unpivoting (Bonus)
Sometimes you need to reverse a pivot—convert columns back to rows. Before Spark 3.4, PySpark had no built-in unpivot function, but the SQL `stack` function handles this elegantly.
```python
# Start with our pivoted DataFrame
pivoted_df = (
    df
    .groupBy("product")
    .pivot("month", ["Jan", "Feb", "Mar"])
    .agg(sum("revenue"))
)

# Unpivot using stack
unpivoted_df = pivoted_df.selectExpr(
    "product",
    "stack(3, 'Jan', Jan, 'Feb', Feb, 'Mar', Mar) as (month, revenue)"
)
unpivoted_df.show()
```
Output:
```
+-------+-----+-------+
|product|month|revenue|
+-------+-----+-------+
| Gadget|  Jan|   1550|
| Gadget|  Feb|   1830|
| Gadget|  Mar|   2120|
| Widget|  Jan|   2100|
| Widget|  Feb|   2500|
| Widget|  Mar|   1950|
+-------+-----+-------+
```
The `stack(n, ...)` function takes `n` as the number of output rows per input row, followed by alternating key-value pairs. It's not the most intuitive syntax, but it works reliably.
For dynamic unpivoting when you don’t know the columns ahead of time:
```python
# Get column names to unpivot (excluding the grouping column)
value_columns = [c for c in pivoted_df.columns if c != "product"]

# Build the stack expression dynamically
stack_expr = f"stack({len(value_columns)}, " + ", ".join(
    [f"'{c}', `{c}`" for c in value_columns]
) + ") as (month, revenue)"

dynamic_unpivot = pivoted_df.selectExpr("product", stack_expr)
Pivoting in PySpark is straightforward once you understand the pattern. The key is being intentional: specify your pivot values explicitly, watch for high-cardinality columns, and remember that wide DataFrames have their own performance characteristics. Master these fundamentals, and you’ll reshape data confidently at any scale.