PySpark: Handling Skewed Data
Key Insights
- Data skew causes a small number of tasks to process disproportionately large partitions, turning a 10-minute job into a 2-hour nightmare with potential out-of-memory failures
- Salting remains the most reliable technique for skewed joins, but Spark 3.0’s Adaptive Query Execution can handle moderate skew automatically with proper configuration
- Always diagnose before optimizing: use `groupBy().count()` and Spark UI metrics to confirm skew exists and quantify its severity before implementing fixes
The Data Skew Problem
Data skew occurs when certain keys in your dataset appear far more frequently than others, causing uneven distribution of work across your Spark cluster. In a perfectly balanced world, each partition processes roughly the same amount of data. In reality, you’re dealing with Pareto distributions everywhere.
Consider these common scenarios: e-commerce platforms where 1% of products generate 50% of orders, social networks where celebrity accounts have millions of followers while average users have hundreds, or geographic data where major cities dominate transaction volumes. When you join or aggregate on these skewed keys, Spark assigns all records with the same key to the same partition—and one executor gets crushed while others sit idle.
The symptoms are unmistakable: jobs that should complete in minutes hang for hours, the Spark UI shows one task running while 199 others finished long ago, and eventually you hit OOM errors on that single overwhelmed executor. I’ve seen production jobs fail repeatedly until someone finally diagnosed the underlying skew problem.
Identifying Skewed Data
Before implementing fixes, confirm that skew is actually your problem. The Spark UI provides immediate visibility—navigate to the Stages tab and examine task duration distribution. If the median task time is 30 seconds but the max is 45 minutes, you have severe skew.
Programmatically, analyze your key distribution before running expensive operations:
```python
from pyspark.sql import functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("SkewDetection").getOrCreate()

# Load your dataset
orders_df = spark.read.parquet("/data/orders")

# Analyze key distribution
key_distribution = (
    orders_df
    .groupBy("customer_id")
    .agg(F.count("*").alias("record_count"))
    .orderBy(F.desc("record_count"))
)

# Calculate skew metrics
stats = key_distribution.agg(
    F.mean("record_count").alias("mean_count"),
    F.stddev("record_count").alias("stddev_count"),
    F.max("record_count").alias("max_count"),
    F.expr("percentile_approx(record_count, 0.5)").alias("median_count"),
    F.expr("percentile_approx(record_count, 0.99)").alias("p99_count"),
).collect()[0]

print(f"Mean: {stats['mean_count']:.2f}")
print(f"Median: {stats['median_count']:.2f}")
print(f"P99: {stats['p99_count']:.2f}")
print(f"Max: {stats['max_count']:.2f}")
print(f"Skew ratio (max/median): {stats['max_count'] / stats['median_count']:.2f}")

# Show the top offenders
print("\nTop 10 skewed keys:")
key_distribution.show(10)
```
A skew ratio above 100x warrants intervention. If your max key has 10 million records while the median has 100, you’ve found your problem.
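To make that rule of thumb reusable across pipelines, the thresholding can be wrapped in a small plain-Python helper. The function name and the 10x/100x cutoffs below are my own illustrative choices, not anything Spark defines:

```python
def skew_severity(max_count: int, median_count: int) -> str:
    """Classify key skew from the max and median per-key record counts.

    The 10x/100x cutoffs are illustrative; tune them for your workload.
    """
    if median_count <= 0:
        raise ValueError("median_count must be positive")
    ratio = max_count / median_count
    if ratio >= 100:
        return "severe"    # intervene: salting or AQE skew-join handling
    if ratio >= 10:
        return "moderate"  # AQE alone is often enough
    return "mild"          # usually safe to leave alone

# The example from the text: max key has 10M records, median has 100
print(skew_severity(10_000_000, 100))  # severe
```

Feeding it the `max_count` and `median_count` values from the metrics query above gives you a quick go/no-go signal before you invest in a fix.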
Salting Technique
Salting artificially distributes hot keys across multiple partitions by appending random suffixes. This technique works exceptionally well for joins where one side has skewed keys.
The approach: add a random salt value (0 to N-1) to the skewed table’s keys, then replicate the smaller table N times with each salt value. After joining on the salted key, aggregate away the salt.
```python
from pyspark.sql import functions as F

# Configuration
SALT_BUCKETS = 10

# Large table with skewed customer_id
orders_df = spark.read.parquet("/data/orders")

# Smaller lookup table
customers_df = spark.read.parquet("/data/customers")

# Add random salt to the large table
salted_orders = orders_df.withColumn(
    "salt",
    (F.rand() * SALT_BUCKETS).cast("int")
).withColumn(
    "salted_customer_id",
    F.concat(F.col("customer_id"), F.lit("_"), F.col("salt"))
)

# Explode the smaller table to match all salt values
salt_array = F.array([F.lit(i) for i in range(SALT_BUCKETS)])
exploded_customers = (
    customers_df
    .withColumn("salt", F.explode(salt_array))
    .withColumn(
        "salted_customer_id",
        F.concat(F.col("customer_id"), F.lit("_"), F.col("salt"))
    )
    # Drop this side's customer_id to avoid a duplicate column after the join
    .drop("customer_id")
)

# Perform the salted join
result = salted_orders.join(
    exploded_customers,
    on="salted_customer_id",
    how="inner"
).drop("salt", "salted_customer_id")

# The join now distributes hot keys across SALT_BUCKETS partitions
result.write.parquet("/data/enriched_orders")
```
The tradeoff: you’re replicating the smaller table N times, increasing memory usage. Choose your salt bucket count based on skew severity—start with 10 and increase if needed. For extreme skew, I’ve used 100+ buckets successfully.
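Picking the bucket count is a tradeoff between splitting the hot key and inflating the small table. A rough back-of-envelope helper, in plain Python (the sizing heuristic and function name are mine, not a Spark API):

```python
import math

def salt_bucket_plan(skew_ratio: float, small_table_mb: float,
                     max_replicated_mb: float = 2048) -> int:
    """Suggest a salt bucket count.

    Heuristic: enough buckets to bring the hot key down to ~10x residual
    skew, capped so the replicated small table stays under a memory budget.
    """
    wanted = max(1, math.ceil(skew_ratio / 10))
    budget = max(1, int(max_replicated_mb // small_table_mb))
    return min(wanted, budget)

# 500x skew, 50 MB lookup table, 2 GB replication budget
print(salt_bucket_plan(500, 50))  # 40 (50 buckets wanted, 40 affordable)
```

The exact numbers matter less than the shape of the reasoning: more buckets help the join, but every bucket is another full copy of the small table moving through the shuffle.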
Adaptive Query Execution (AQE)
Spark 3.0 introduced Adaptive Query Execution, which can automatically detect and handle skewed joins at runtime. This is the easiest solution when it works, but it has limitations.
```python
from pyspark.sql import SparkSession

# Enable AQE with skew join optimization
spark = (
    SparkSession.builder
    .appName("AQEOptimization")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.skewJoin.enabled", "true")
    .config("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
    .config("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
    .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")
    .getOrCreate()
)

# With AQE enabled, standard joins may handle moderate skew automatically
orders_df = spark.read.parquet("/data/orders")
customers_df = spark.read.parquet("/data/customers")

# AQE will split skewed partitions during execution
result = orders_df.join(customers_df, on="customer_id", how="inner")

# Check whether AQE kicked in via the SQL tab in the Spark UI:
# look for "CustomShuffleReader" nodes ("AQEShuffleRead" in Spark 3.2+)
result.explain(mode="formatted")
```
The skewedPartitionFactor setting (default 5) means a partition is considered skewed if it’s 5x larger than the median partition size AND exceeds the threshold in bytes. Tune these based on your cluster resources and data characteristics.
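The detection rule can be sketched in plain Python to make the AND explicit. This is a simplification of what Spark evaluates internally, and the function name is mine:

```python
def is_skewed_partition(size_bytes: int, median_size_bytes: int,
                        factor: float = 5.0,
                        threshold_bytes: int = 256 * 1024 * 1024) -> bool:
    """Mirror AQE's skewed-partition test: BOTH conditions must hold."""
    return (size_bytes > factor * median_size_bytes
            and size_bytes > threshold_bytes)

mib = 1024 * 1024
# 10x the median, but only 100 MiB: below the byte threshold, not skewed
print(is_skewed_partition(100 * mib, 10 * mib))   # False
# 1 GiB partition against a 64 MiB median: both conditions hold
print(is_skewed_partition(1024 * mib, 64 * mib))  # True
```

The byte threshold is what keeps AQE from splitting partitions that are relatively large but absolutely tiny, where the extra tasks would cost more than they save.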
AQE works well for moderate skew but may not fully resolve extreme cases. Monitor your jobs after enabling it—if you still see straggler tasks, combine AQE with manual salting.
Broadcast Joins for Small Tables
When one side of your join is small enough to fit in executor memory, broadcast joins eliminate shuffles entirely—and with them, any skew concerns on the join operation itself.
```python
from pyspark.sql.functions import broadcast

# Configure broadcast threshold (default is 10MB)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "100MB")

# Small dimension table (under threshold)
products_df = spark.read.parquet("/data/products")  # ~50MB

# Large fact table with potentially skewed product_id
transactions_df = spark.read.parquet("/data/transactions")  # 500GB

# Explicit broadcast hint for clarity and control
result = transactions_df.join(
    broadcast(products_df),
    on="product_id",
    how="inner"
)

# Verify a broadcast join was used:
# look for "BroadcastHashJoin" in the plan
result.explain()
```
Broadcast joins copy the small table to every executor, so the large table’s partitions never need to shuffle. This completely sidesteps skew issues for the join itself. The catch: if your “small” table is actually 2GB, you’ll blow out executor memory. Be conservative with thresholds and monitor memory usage.
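Before raising the threshold, it is worth sanity-checking that the broadcast copy actually fits. A crude plain-Python check under stated assumptions (the names, the inflation factor, and the memory fraction are mine; deserialized in-memory size is typically several times the on-disk Parquet size):

```python
def fits_broadcast(on_disk_mb: float, executor_mem_mb: float,
                   inflation: float = 4.0, max_fraction: float = 0.2) -> bool:
    """Crude check: the inflated table must stay well under executor memory.

    inflation ~4x accounts for decompression and JVM object overhead;
    max_fraction keeps the broadcast from crowding out working memory.
    """
    return on_disk_mb * inflation <= executor_mem_mb * max_fraction

# 50 MB Parquet table, 8 GB executors: ~200 MB in memory vs a ~1.6 GB budget
print(fits_broadcast(50, 8192))    # True
# A 2 GB "small" table: ~8 GB in memory blows the budget
print(fits_broadcast(2048, 8192))  # False
```

If the check fails, fall back to a salted shuffle join rather than forcing the broadcast.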
Custom Partitioning Strategies
For aggregations on skewed data, isolate the hot keys and aggregate them with a salted two-phase pass: a partial aggregate over (key, salt) spreads each hot key across many tasks, and a second aggregate merges the partials. Note that simply repartitioning by the key does not help, because hash partitioning still sends every record for a hot key to the same partition.

```python
from pyspark.sql import functions as F

SALT_BUCKETS = 20  # tasks per hot key in the partial aggregation

# Identify your hot keys (top N by frequency)
orders_df = spark.read.parquet("/data/orders")

hot_keys = (
    orders_df
    .groupBy("product_id")
    .count()
    .orderBy(F.desc("count"))
    .limit(100)
    .select("product_id")
    .collect()
)
hot_key_list = [row["product_id"] for row in hot_keys]

# Split the data
hot_data = orders_df.filter(F.col("product_id").isin(hot_key_list))
normal_data = orders_df.filter(~F.col("product_id").isin(hot_key_list))

# Phase 1: salted partial aggregation spreads each hot key over SALT_BUCKETS tasks
hot_partial = (
    hot_data
    .withColumn("salt", (F.rand() * SALT_BUCKETS).cast("int"))
    .groupBy("product_id", "salt")
    .agg(
        F.sum("amount").alias("partial_amount"),
        F.count("*").alias("partial_count")
    )
)

# Phase 2: merge the partials (tiny shuffle: at most 100 keys x SALT_BUCKETS rows)
hot_aggregated = (
    hot_partial
    .groupBy("product_id")
    .agg(
        F.sum("partial_amount").alias("total_amount"),
        F.sum("partial_count").alias("order_count")
    )
)

# Process normal keys with a standard single-pass aggregation
normal_aggregated = (
    normal_data
    .groupBy("product_id")
    .agg(
        F.sum("amount").alias("total_amount"),
        F.count("*").alias("order_count")
    )
)

# Combine results (unionByName avoids relying on column order)
final_result = hot_aggregated.unionByName(normal_aggregated)
final_result.write.parquet("/data/product_aggregates")
```
This approach gives you fine-grained control over resource allocation: known problem keys get extra parallelism while the long tail is processed efficiently with standard settings.
Summary and Best Practices
Choose your skew-handling technique based on your specific situation:
- Small lookup table? Use broadcast joins—they eliminate the problem entirely
- Spark 3.0+ with moderate skew? Enable AQE first—it’s zero-code and often sufficient
- Severe skew in joins? Implement salting—it’s more work but handles extreme cases
- Skewed aggregations? Use two-phase aggregation with separate hot-key processing
Always validate your fix worked. After implementing any technique, check the Spark UI for task duration distribution. Your max task time should now be within 2-3x of the median, not 100x.
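That validation step can itself be automated against task durations pulled from the Spark UI or event logs. A plain-Python sketch (the function names are mine; the 3x cutoff mirrors the rule of thumb above):

```python
import statistics

def straggler_ratio(task_durations_s: list) -> float:
    """Max task duration divided by the median task duration."""
    return max(task_durations_s) / statistics.median(task_durations_s)

def skew_fixed(task_durations_s: list, max_ratio: float = 3.0) -> bool:
    """True if the worst task finished within max_ratio of the median."""
    return straggler_ratio(task_durations_s) <= max_ratio

# Before the fix: 199 tasks at ~30s, one straggler at 45 minutes
before = [30.0] * 199 + [2700.0]
# After the fix: the worst task is within 3x of the median
after = [30.0] * 199 + [75.0]

print(straggler_ratio(before))  # 90.0
print(skew_fixed(before))       # False
print(skew_fixed(after))        # True
```

Wired into a post-job check, this turns "eyeball the Stages tab" into an assertion your pipeline can fail on.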
Monitor continuously. Skew patterns change as your data grows. That customer who was average last year might become your biggest account tomorrow. Build skew detection into your pipeline monitoring so you catch new hot keys before they cause production incidents.