PySpark - OOM (Out of Memory) Solutions
Out of memory errors in PySpark fall into two distinct categories, and misdiagnosing which one you're dealing with wastes hours of debugging time.
Key Insights
- Most PySpark OOM errors stem from three root causes: data skew, improper memory configuration, or anti-patterns like calling `collect()` on large datasets—fix these first before throwing more hardware at the problem.
- Driver and executor OOM failures require different solutions; driver issues typically involve collecting too much data, while executor issues point to partition-level problems or insufficient memory allocation.
- The salting technique for skewed joins and strategic use of broadcast joins can eliminate 90% of memory-related performance issues in production Spark jobs.
Understanding OOM Errors in PySpark
Driver OOM occurs when the driver process—the coordinator running your main application—exhausts its memory. You’ll see errors like java.lang.OutOfMemoryError: Java heap space in your driver logs. This almost always means you’re pulling too much data back to a single node via collect(), toPandas(), or similar operations.
Executor OOM happens on the worker nodes processing your data. The error messages reference specific executor IDs and often mention GC overhead or heap space exhaustion. These failures indicate problems at the partition level—either individual partitions are too large, or your transformations require more memory than allocated.
The Spark UI is your first stop for diagnosis. Navigate to the Stages tab and look for tasks with significantly longer execution times or failed attempts. The Storage tab reveals cached data consuming memory, while the Executors tab shows memory usage per executor.
Common root causes I see repeatedly in production:
- Data skew: One partition holds 10x more data than others
- Configuration mismatches: Default memory settings on large datasets
- Anti-pattern abuse: Using `collect()` or `toPandas()` without size limits
- Unbounded aggregations: GroupBy operations creating massive intermediate results
Memory Configuration Tuning
Spark’s default memory settings assume small workloads. Production jobs require explicit configuration.
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("OptimizedMemoryConfig") \
    .config("spark.executor.memory", "8g") \
    .config("spark.executor.memoryOverhead", "2g") \
    .config("spark.driver.memory", "4g") \
    .config("spark.driver.memoryOverhead", "1g") \
    .config("spark.memory.fraction", "0.8") \
    .config("spark.memory.storageFraction", "0.3") \
    .config("spark.sql.shuffle.partitions", "200") \
    .config("spark.default.parallelism", "200") \
    .getOrCreate()
```
Here’s what each parameter controls:
- `spark.executor.memory`: Heap memory per executor. Start with 4-8GB and scale based on your data.
- `spark.executor.memoryOverhead`: Off-heap memory for Python processes, network buffers, and internal metadata. Set to 10-20% of executor memory, minimum 384MB.
- `spark.memory.fraction`: Portion of heap used for execution and storage (default 0.6). Increase to 0.8 if you're not running other JVM processes.
- `spark.memory.storageFraction`: Portion of `memory.fraction` reserved for cached data. Lower this if you cache rarely.
A critical mistake: setting executor memory too high. Executors with 32GB+ heap suffer from garbage collection pauses. Use more executors with moderate memory (8-16GB) rather than fewer executors with massive heaps.
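The sizing rules above can be captured in a small helper. This is a hypothetical sketch (the function name and the even-split assumption are mine), applying the 10-20% overhead rule of thumb with Spark's 384MB floor:

```python
def executor_sizing(node_memory_gb, executors_per_node, overhead_fraction=0.10):
    """Rough heap/overhead split per executor on one worker node.

    Hypothetical helper: assumes node memory is divided evenly across
    executors and overhead follows the 10-20% rule with a 384MB minimum.
    """
    per_executor_gb = node_memory_gb / executors_per_node
    overhead_gb = max(0.384, per_executor_gb * overhead_fraction)
    heap_gb = per_executor_gb - overhead_gb
    return round(heap_gb, 2), round(overhead_gb, 2)

# A 64GB node split into 4 executors stays in the moderate 8-16GB heap range
heap, overhead = executor_sizing(64, 4)
```

Splitting a 64GB node into four executors rather than one keeps each heap well under the 32GB mark where GC pauses become painful.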
Handling Data Skew
Data skew kills Spark jobs silently. One partition with 100 million rows while others have 1 million means one task runs 100x longer—and likely fails.
Detect skew in the Spark UI by examining task durations in the Stages tab. If the max duration is 10x the median, you have skew. The SQL tab also shows partition statistics for shuffle operations.
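You can also measure skew directly from key counts instead of eyeballing the UI. The ratio check itself is plain arithmetic; the commented Spark usage below assumes an `orders` DataFrame keyed by `customer_id`, matching the join example later in this section:

```python
def skew_ratio(max_count, median_count):
    """Ratio of the heaviest key to the median key; >= 10 suggests harmful skew."""
    return float("inf") if median_count == 0 else max_count / median_count

# Spark usage (assumes an existing 'orders' DataFrame):
# from pyspark.sql import functions as F
# stats = (orders.groupBy("customer_id").count()
#          .agg(F.max("count").alias("max_n"),
#               F.expr("percentile_approx(count, 0.5)").alias("median_n"))
#          .first())
# if skew_ratio(stats["max_n"], stats["median_n"]) >= 10:
#     print("customer_id is skewed; consider salting or AQE skew join")
```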
The salting technique fixes skewed joins by distributing hot keys across multiple partitions:
```python
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# Assume 'orders' has skewed customer_id (some customers have millions of orders)
# and 'customers' is the dimension table
SALT_BUCKETS = 10

# Salt the large (skewed) table
orders_salted = orders.withColumn(
    "salt",
    (F.rand() * SALT_BUCKETS).cast(IntegerType())
).withColumn(
    "customer_id_salted",
    F.concat(F.col("customer_id"), F.lit("_"), F.col("salt"))
)

# Explode the small table to match all salt values
customers_exploded = customers.crossJoin(
    spark.range(0, SALT_BUCKETS).withColumnRenamed("id", "salt")
).withColumn(
    "customer_id_salted",
    F.concat(F.col("customer_id"), F.lit("_"), F.col("salt"))
).drop("customer_id")  # avoid a duplicate customer_id column after the join

# Join on salted keys
result = orders_salted.join(
    customers_exploded,
    on="customer_id_salted",
    how="inner"
).drop("salt", "customer_id_salted")
```
This distributes a single hot key across 10 partitions, reducing per-partition memory by 10x. The trade-off is replicating the small table, so only salt when you’ve confirmed skew exists.
Optimizing Transformations and Actions
Certain operations are memory landmines. Avoid them or use safer alternatives.
Never use collect() on unbounded data:
```python
# BAD: Will crash on large datasets
all_data = df.collect()

# GOOD: Limit explicitly
sample_data = df.limit(1000).collect()

# GOOD: Use take() for small samples
first_100 = df.take(100)
```
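When you genuinely need every row on the driver (for example, feeding an external system), `toLocalIterator()` streams one partition at a time instead of materializing the whole dataset. A sketch with a hypothetical batching helper:

```python
def iter_batches(rows, batch_size=1000):
    """Group an iterator of rows into lists of at most batch_size."""
    batch = []
    for row in rows:
        batch.append(row)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

# Spark usage: only one partition's rows sit on the driver at a time
# for batch in iter_batches(df.toLocalIterator(), batch_size=10_000):
#     write_to_external_system(batch)  # hypothetical sink
```

The trade-off is speed: partitions are fetched sequentially, so this is for bounded-memory export, not fast analysis.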
Broadcast small tables for joins:
```python
from pyspark.sql.functions import broadcast

# Small dimension table (< 100MB)
dim_products = spark.read.parquet("s3://bucket/products")

# Large fact table
fact_sales = spark.read.parquet("s3://bucket/sales")

# Without broadcast: shuffle join (expensive)
result_shuffle = fact_sales.join(dim_products, on="product_id")

# With broadcast: no shuffle, replicated to all executors
result_broadcast = fact_sales.join(
    broadcast(dim_products),
    on="product_id"
)
```
Broadcast joins eliminate shuffle entirely. Spark automatically broadcasts tables under 10MB (configurable via spark.sql.autoBroadcastJoinThreshold), but explicit broadcasting ensures the optimization applies.
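If you would rather tune the automatic behavior than annotate every join, the threshold is adjustable at runtime. A fragment, assuming an existing `spark` session:

```python
# Raise the auto-broadcast cutoff to 50MB
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(50 * 1024 * 1024))

# Or disable auto-broadcast entirely and rely on explicit broadcast() hints
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
```

Either way, confirm the plan with `df.explain()`: a broadcast join appears as a `BroadcastHashJoin` node rather than a `SortMergeJoin`.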
Replace groupByKey() with reduceByKey() when working with RDDs:
```python
# BAD: Shuffles all values to reducers
rdd.groupByKey().mapValues(sum)

# GOOD: Combines locally before shuffle
rdd.reduceByKey(lambda a, b: a + b)
```
For DataFrames, prefer built-in aggregation functions over UDFs—they’re optimized and avoid Python serialization overhead.
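To make the UDF point concrete, here is a sketch of the same transformation both ways, assuming a DataFrame `df` with a string `name` column (the column names are illustrative):

```python
from pyspark.sql import functions as F
from pyspark.sql.types import StringType

# BAD: a Python UDF forces every row across the JVM/Python boundary
upper_udf = F.udf(lambda s: s.upper() if s is not None else None, StringType())
df_slow = df.withColumn("name_upper", upper_udf(F.col("name")))

# GOOD: the built-in stays in the JVM and Catalyst can optimize it
df_fast = df.withColumn("name_upper", F.upper(F.col("name")))
```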
Efficient Data Partitioning
Partition size directly impacts memory usage. The target is 128MB per partition—large enough to amortize task overhead, small enough to avoid memory pressure.
```python
def repartition_by_size(df, target_partition_mb=128):
    """Repartition a DataFrame toward roughly target_partition_mb per partition."""
    # Sample to estimate average row size (faster than a full scan)
    sample_df = df.sample(fraction=0.01)
    sample_df.cache()
    if sample_df.count() == 0:
        sample_df.unpersist()
        return df

    # Rough per-row size from the sample's string representation
    avg_row_bytes = sample_df.rdd.map(lambda r: len(str(r))).mean()
    sample_df.unpersist()

    total_count = df.count()
    rows_per_partition = (target_partition_mb * 1024 * 1024) / avg_row_bytes
    optimal_partitions = max(1, int(total_count / rows_per_partition))
    return df.repartition(optimal_partitions)

# Usage
df_optimized = repartition_by_size(large_df, target_partition_mb=128)
```
Use repartition() when you need to increase partitions or redistribute data evenly. Use coalesce() when reducing partitions—it avoids a full shuffle by combining existing partitions.
```python
# Reducing partitions: use coalesce (no shuffle)
df_small = df.coalesce(10)

# Increasing partitions or fixing skew: use repartition (full shuffle)
df_even = df.repartition(200)

# Repartition by column for join optimization
df_partitioned = df.repartition(200, "join_key")
```
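To verify that a repartition actually balanced the data, you can compare per-partition row counts. The ratio helper is plain Python; the commented Spark usage assumes an existing DataFrame `df` (note that `glom().map(len).collect()` pulls counts, not data, but still scans every partition):

```python
def partition_balance(sizes):
    """Max-to-mean ratio of partition row counts; near 1.0 means balanced."""
    mean = sum(sizes) / len(sizes)
    return max(sizes) / mean if mean else 0.0

# Spark usage:
# sizes = df.rdd.glom().map(len).collect()
# print(df.rdd.getNumPartitions(), partition_balance(sizes))
```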
Caching and Persistence Best Practices
Caching is not free. Each cached DataFrame consumes memory that could serve computation. Cache only when data is reused multiple times.
```python
from pyspark import StorageLevel
from pyspark.sql import functions as F

def run_multi_stage_pipeline(raw_df):
    # Stage 1: Heavy transformation, reused 3 times
    cleaned_df = raw_df \
        .filter(F.col("valid") == True) \
        .withColumn("normalized", F.lower(F.col("text"))) \
        .withColumn("processed_date", F.current_timestamp())

    # Cache because we'll use this 3 times
    # MEMORY_AND_DISK spills to disk if memory is tight
    cleaned_df.persist(StorageLevel.MEMORY_AND_DISK)

    # Force materialization
    cleaned_df.count()

    # Stage 2: Multiple aggregations on cached data
    daily_stats = cleaned_df.groupBy("date").agg(F.count("*").alias("count"))
    user_stats = cleaned_df.groupBy("user_id").agg(F.sum("amount").alias("total"))
    product_stats = cleaned_df.groupBy("product_id").agg(F.avg("rating").alias("avg_rating"))

    # Write outputs
    daily_stats.write.parquet("s3://bucket/daily_stats")
    user_stats.write.parquet("s3://bucket/user_stats")
    product_stats.write.parquet("s3://bucket/product_stats")

    # CRITICAL: Unpersist when done
    cleaned_df.unpersist()
    return "Pipeline complete"
```
Storage level selection:
- `MEMORY_ONLY`: Fastest, but recomputes if evicted. Use for data that's cheap to recompute.
- `MEMORY_AND_DISK`: Spills to disk when memory is full. Default choice for most cases.
- `DISK_ONLY`: Slowest, but guaranteed to persist. Use for very large datasets.
- `MEMORY_ONLY_SER`: Serialized storage uses less memory but requires CPU for deserialization. (Scala/Java API; PySpark always stores data serialized, so `MEMORY_ONLY` already behaves this way.)
Monitoring and Debugging Tools
Enable GC logging to diagnose memory pressure:
```python
spark = SparkSession.builder \
    .config("spark.executor.extraJavaOptions",
            "-XX:+PrintGCDetails -XX:+PrintGCTimeStamps") \
    .getOrCreate()
```
OOM Troubleshooting Checklist:
- Check Spark UI Stages tab for skewed tasks (max duration ≫ median)
- Review Executors tab for memory usage patterns
- Search logs for `OutOfMemoryError` to identify driver vs. executor
- Look for `collect()`, `toPandas()`, or `broadcast()` on large data
- Verify shuffle partition count matches data volume
- Confirm cached DataFrames are unpersisted after use
- Check for Cartesian joins or exploding arrays
When all else fails, enable adaptive query execution (AQE), available since Spark 3.0 and on by default since Spark 3.2:
```python
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
```
AQE automatically handles partition coalescing and skew joins at runtime, eliminating much of the manual tuning described above. It’s the single most impactful configuration for Spark 3.x deployments.
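AQE's skew handling has its own knobs worth knowing. A fragment, assuming an existing `spark` session; the values shown are, to my understanding, the Spark 3.x defaults:

```python
# A partition is treated as skewed only when it is both skewedPartitionFactor
# times larger than the median partition and above the byte threshold
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")

# Target size AQE aims for when splitting or coalescing partitions
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "64MB")
```

If AQE leaves a skewed join untouched, it is usually because the heavy partition falls under one of these two thresholds.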