Apache Spark - Data Skew Detection and Solutions
Key Insights
- Data skew causes a small number of tasks to process disproportionately large amounts of data, turning a 10-minute job into a 2-hour nightmare while most executors sit idle
- Detection is straightforward once you know where to look: the Spark UI’s task duration metrics and a simple groupBy().count() query will expose problematic keys immediately
- Salting remains the most reliable manual fix for severe skew, while Spark 3.0+’s Adaptive Query Execution handles moderate cases automatically with minimal configuration
Introduction to Data Skew
Data skew is the silent killer of Spark job performance. It occurs when data is unevenly distributed across partitions, causing some tasks to process significantly more records than others. While 199 of your 200 tasks complete in 30 seconds, that one straggler chugs along for 45 minutes processing a partition with 100x more data.
The consequences are brutal: wasted cluster resources (executors sitting idle waiting for the straggler), potential out-of-memory errors on overloaded executors, and job runtimes dominated by your slowest task. I’ve seen production jobs go from 15 minutes to 4 hours because of a single skewed key.
Skew happens naturally in real-world data. Customer transactions cluster around popular products. Log events spike during peak hours. Null values accumulate in optional fields. Any time you partition, join, or aggregate on these uneven distributions, you’re asking for trouble.
Identifying Data Skew in Your Jobs
The Spark UI is your first diagnostic tool. Navigate to the Stages tab and look at the task metrics for shuffle-heavy stages. The telltale signs are obvious once you know what to look for:
- Task duration variance: If your median task time is 10 seconds but the max is 10 minutes, you have skew
- Shuffle read size variance: One task reading 5GB while others read 50MB indicates partition imbalance
- GC time spikes: Stragglers often show elevated garbage collection as they struggle with memory pressure
The Summary Metrics section shows min, 25th percentile, median, 75th percentile, and max for each metric. A healthy job shows these values clustered together. Skewed jobs show the max as an extreme outlier.
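You can turn that visual check into a number. A minimal pure-Python sketch of the heuristic (the duration list here is illustrative, not pulled from a real job):

```python
def skew_ratio(task_durations):
    """Return max/median task duration; a ratio above ~10 suggests serious skew."""
    durations = sorted(task_durations)
    median = durations[len(durations) // 2]
    return durations[-1] / max(median, 1e-9)

# Illustrative: 199 tasks at 30 seconds, one straggler at 45 minutes
durations = [30.0] * 199 + [2700.0]
print(f"skew ratio: {skew_ratio(durations):.0f}x")  # → skew ratio: 90x
```

The same ratio works on shuffle read sizes, which is often a cleaner signal than wall-clock time on a busy cluster.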
Before diving into fixes, profile your data to identify the culprits:
```python
from pyspark.sql import functions as F

# df is your input DataFrame
# Analyze key distribution for a potential join key
key_distribution = (
    df.groupBy("customer_id")
    .count()
    .orderBy(F.desc("count"))
)

# Get statistics
stats = key_distribution.select(
    F.count("*").alias("unique_keys"),
    F.sum("count").alias("total_records"),
    F.max("count").alias("max_per_key"),
    F.avg("count").alias("avg_per_key"),
    F.percentile_approx("count", 0.99).alias("p99_per_key")
)
stats.show()

# Show the hot keys
print("Top 10 hottest keys:")
key_distribution.show(10)
```
When the max is 1000x the average, you’ve found your problem. I typically flag any key representing more than 1% of total records as a potential skew risk.
Common Causes of Data Skew
Uneven partition keys are the most common culprit. Null values are notorious—if 30% of your records have a null region_id, all those records end up in the same partition. Popular categories create similar problems: the “Electronics” category might have 10 million products while “Artisanal Cheese Graters” has 12.
Skewed joins amplify the problem. When you join two datasets on a skewed key, the task handling the hot key must process the cartesian product of matching records from both sides.
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import time

spark = SparkSession.builder.appName("SkewDemo").getOrCreate()

# Create a skewed transactions dataset
# ~90% of transactions belong to customer_id = 1 (the "hot" key)
transactions_data = (
    [(1, f"txn_{i}", 100.0) for i in range(900000)]  # Hot key
    + [(i, f"txn_{i}", 50.0) for i in range(2, 100001)]  # Normal keys
)
transactions = spark.createDataFrame(
    transactions_data,
    ["customer_id", "transaction_id", "amount"]
)

# Customer dimension table
customers = spark.createDataFrame(
    [(i, f"Customer_{i}", f"Region_{i % 10}") for i in range(1, 100001)],
    ["customer_id", "name", "region"]
)

# This join will be heavily skewed
start = time.time()
result = transactions.join(customers, "customer_id")
result.write.format("noop").mode("overwrite").save()
print(f"Skewed join completed in {time.time() - start:.2f} seconds")
```
Aggregations on low-cardinality columns cause similar issues. Grouping by country_code when 60% of your users are in the US means one reducer handles the majority of your data.
Solution: Salting Technique
Salting artificially increases key cardinality by appending random values. Instead of all records for customer_id = 1 going to one partition, you distribute them across 1_0, 1_1, 1_2, etc.
The trade-off: you must replicate the dimension table for each salt value and perform a final aggregation to combine salted results. It’s more computation, but parallelized computation beats a single-threaded straggler every time.
```python
from pyspark.sql import functions as F

SALT_BUCKETS = 10

# Identify hot keys (keys with > 10000 records)
hot_keys = (
    transactions.groupBy("customer_id")
    .count()
    .filter(F.col("count") > 10000)
    .select("customer_id")
    .collect()
)
hot_key_set = {row.customer_id for row in hot_keys}
hot_key_broadcast = spark.sparkContext.broadcast(hot_key_set)

# Add salt to transactions (only for hot keys)
def add_salt(customer_id):
    import random
    if customer_id in hot_key_broadcast.value:
        return random.randint(0, SALT_BUCKETS - 1)
    return 0

# Declare the return type explicitly; the default is string, which would
# silently turn our integer salts into nulls
add_salt_udf = F.udf(add_salt, "int")

salted_transactions = transactions.withColumn(
    "salt",
    add_salt_udf(F.col("customer_id"))
).withColumn(
    "salted_key",
    F.concat(F.col("customer_id"), F.lit("_"), F.col("salt"))
)

# Explode the dimension table for hot keys
customer_with_salt = customers.crossJoin(
    spark.range(SALT_BUCKETS).withColumnRenamed("id", "salt")
).withColumn(
    "salted_key",
    F.concat(F.col("customer_id"), F.lit("_"), F.col("salt"))
)

# For non-hot keys, use the original key with salt=0
non_hot_customers = customers.withColumn(
    "salt", F.lit(0)
).withColumn(
    "salted_key",
    F.concat(F.col("customer_id"), F.lit("_"), F.col("salt"))
)

# Combine: exploded hot keys + non-exploded regular keys
hot_customer_ids = [row.customer_id for row in hot_keys]
exploded_customers = customer_with_salt.filter(
    F.col("customer_id").isin(hot_customer_ids)
).union(
    non_hot_customers.filter(~F.col("customer_id").isin(hot_customer_ids))
)

# Perform the salted join
salted_result = salted_transactions.join(
    exploded_customers.select("salted_key", "name", "region"),
    "salted_key"
)
```
Solution: Adaptive Query Execution (AQE)
Spark 3.0 introduced Adaptive Query Execution, which detects and handles skew automatically at runtime. It splits skewed partitions into smaller chunks and replicates the matching partition from the other side of the join.
Enable it with these configurations:
```python
spark = SparkSession.builder \
    .appName("AQE Demo") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5") \
    .config("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB") \
    .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB") \
    .getOrCreate()

# With AQE enabled, the same skewed join is handled automatically
result = transactions.join(customers, "customer_id")
```
The skewedPartitionFactor (default: 5) determines how much larger than median a partition must be to qualify as skewed. The skewedPartitionThresholdInBytes (default: 256MB) sets an absolute minimum size threshold.
AQE works well for moderate skew but has limitations. It only handles sort-merge joins (not shuffle hash joins), requires statistics collection, and adds planning overhead. For extreme skew where one key represents 50%+ of your data, manual salting often performs better.
Check whether AQE kicked in by examining the SQL tab in the Spark UI: look for “CustomShuffleReader” nodes (renamed “AQEShuffleRead” in Spark 3.2+) indicating runtime optimization.
Solution: Broadcast Joins and Isolation
When one table is small enough (default threshold: 10MB, configurable up to a few GB), broadcast it to avoid shuffling entirely:
```python
from pyspark.sql import functions as F

# Force broadcast of the smaller table
result = transactions.join(
    F.broadcast(customers),
    "customer_id"
)
```
For cases where broadcast isn’t feasible, isolate hot keys for separate processing:
```python
# Separate hot and cold paths
hot_key_list = [1]  # Known hot keys

hot_transactions = transactions.filter(
    F.col("customer_id").isin(hot_key_list)
)
cold_transactions = transactions.filter(
    ~F.col("customer_id").isin(hot_key_list)
)
hot_customers = customers.filter(
    F.col("customer_id").isin(hot_key_list)
)
cold_customers = customers.filter(
    ~F.col("customer_id").isin(hot_key_list)
)

# Broadcast join for hot keys (small dimension side)
hot_result = hot_transactions.join(
    F.broadcast(hot_customers),
    "customer_id"
)

# Regular join for cold keys (no skew)
cold_result = cold_transactions.join(cold_customers, "customer_id")

# Combine results
final_result = hot_result.union(cold_result)
```
This hybrid approach gives you fine-grained control and often outperforms generic solutions when you know your data’s skew patterns.
Best Practices and Prevention
Profile data before building pipelines. Run distribution analysis on join keys during development, not after production jobs start failing. Add key distribution checks to your data validation suite.
Monitor continuously. Set up alerts for task duration variance. If max task time exceeds 10x median, trigger an investigation. Tools like Spark Measure or custom metrics exporters to Prometheus make this straightforward.
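One lightweight way to wire up such an alert is Spark’s monitoring REST API: the taskSummary endpoint returns per-stage quantile metrics. A hedged sketch (the host/port, IDs, and exact payload shape here are assumptions based on the REST API docs; only the ratio helper is exercised below):

```python
import json
from urllib.request import urlopen

def duration_skew(summary, alert_ratio=10.0):
    """Given a taskSummary payload requested with quantiles=0.5,1.0,
    return (max/median runtime ratio, whether it crosses the alert threshold)."""
    median, worst = summary["executorRunTime"]
    ratio = worst / max(median, 1.0)
    return ratio, ratio > alert_ratio

def check_stage(host, app_id, stage_id, attempt=0):
    # Spark monitoring REST API; quantiles selects which percentiles come back
    url = (f"http://{host}:4040/api/v1/applications/{app_id}"
           f"/stages/{stage_id}/{attempt}/taskSummary?quantiles=0.5,1.0")
    return duration_skew(json.load(urlopen(url)))

# Illustrative payload: median 10s, max 10min
print(duration_skew({"executorRunTime": [10000.0, 600000.0]}))  # → (60.0, True)
```

Poll this from a scheduler or metrics exporter and page someone when the ratio stays high across stages.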
Choose partition keys wisely. Composite keys often distribute better than single columns. (region, date) typically skews less than region alone.
Handle nulls explicitly. Replace nulls with meaningful defaults or filter them for separate processing. Never let nulls silently accumulate in one partition.
Set appropriate parallelism. More partitions mean smaller partition sizes, reducing the impact of moderate skew. Use spark.sql.shuffle.partitions (default 200) scaled to your data volume—I typically use 2-4x the number of cores.
Data skew is inevitable in real-world datasets, but it doesn’t have to destroy your job performance. Detect it early, understand its source, and apply the appropriate fix. Start with AQE for convenience, escalate to salting for severe cases, and always keep broadcast joins in your toolkit for small dimension tables.