PySpark - Memory Error Troubleshooting Guide

Key Insights

  • PySpark memory errors stem from three distinct memory spaces—JVM heap, off-heap, and Python worker memory—and fixing them requires understanding which space is exhausted
  • Most memory issues trace back to data skew, improper partitioning, or collecting too much data to the driver; configuration tuning alone rarely solves the root cause
  • Spark 3.x’s Adaptive Query Execution removes the need for many manual optimizations, but you still need the fundamentals for the cases AQE can’t handle

Understanding PySpark Memory Architecture

PySpark’s memory model confuses even experienced engineers because it spans two runtimes: the JVM and Python. Before troubleshooting any memory error, you need to understand where memory lives.

The driver is your application’s control center. It maintains the SparkContext, coordinates executors, and collects results. When you call collect() or toPandas(), data flows here. Driver memory errors typically mean you’re pulling too much data back.

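Executors do the heavy lifting. Each executor runs on a worker node and processes data partitions. Inside the JVM heap, execution memory (shuffles, joins, sorts, aggregations) and storage memory (cached data) share a unified pool, next to user memory for internal metadata and user data structures. Overhead memory (JVM internals, native buffers, Python worker processes) sits outside the heap.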

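from pyspark.sql import SparkSession

# spark.driver.memory only applies if set before the driver JVM starts; under spark-submit
# or the pyspark shell, use --driver-memory or spark-defaults.conf instead.
# spark.python.worker.memory is the point where Python-side aggregation spills to disk,
# not a hard cap on Python worker memory.
spark = SparkSession.builder \
    .appName("MemoryConfig") \
    .config("spark.driver.memory", "4g") \
    .config("spark.executor.memory", "8g") \
    .config("spark.executor.memoryOverhead", "2g") \
    .config("spark.memory.fraction", "0.6") \
    .config("spark.memory.storageFraction", "0.5") \
    .config("spark.python.worker.memory", "512m") \
    .getOrCreate()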

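The memoryOverhead setting is critical for PySpark. Python workers run outside the JVM heap, and unless spark.executor.pyspark.memory is set, their memory counts against this overhead allocation. If you’re using pandas UDFs or Arrow-based operations, increase this value, because each Arrow batch is materialized as pandas data inside the Python worker. Lowering spark.sql.execution.arrow.maxRecordsPerBatch (10,000 rows by default) shrinks those batches.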
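The heap side of this configuration follows a fixed formula in Spark's unified memory manager: 300MB is reserved, spark.memory.fraction of the remainder becomes the pool shared by execution and storage, and spark.memory.storageFraction of that pool is the storage share that execution cannot evict. The sketch below applies that arithmetic to the 8g / 0.6 / 0.5 settings above. The helper name heap_regions_mb is hypothetical, and the JVM usually reports slightly less usable heap than spark.executor.memory, so the Spark UI shows somewhat smaller figures.

```python
# Reserved memory Spark sets aside before applying spark.memory.fraction
RESERVED_MB = 300


def heap_regions_mb(executor_memory_mb, memory_fraction=0.6, storage_fraction=0.5):
    """Approximate the executor heap regions, in MB (hypothetical helper)."""
    usable = executor_memory_mb - RESERVED_MB
    unified = usable * memory_fraction      # pool shared by execution and storage
    storage = unified * storage_fraction    # cached blocks here are protected from eviction
    return {
        "reserved": RESERVED_MB,
        "unified": unified,
        "storage": storage,
        "execution": unified - storage,     # execution may also borrow unused storage memory
        "user": usable - unified,           # user data structures and Spark internal metadata
    }


regions = heap_regions_mb(8 * 1024)
print({name: round(mb) for name, mb in regions.items()})
# {'reserved': 300, 'unified': 4735, 'storage': 2368, 'execution': 2368, 'user': 3157}
```

An 8g executor therefore gives shuffles and joins roughly 2.3GB of guaranteed execution memory, not 8GB, which is why skewed tasks exhaust it sooner than the headline number suggests.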

Common Memory Error Types and Their Causes

Different errors point to different problems. Learn to read them.

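OutOfMemoryError: Java heap space means the JVM ran out of heap memory. On executors it typically appears during shuffles, joins, and aggregations when a task must hold more data than it can spill, often because of a skewed partition or an oversized broadcast table. On the driver it usually follows collect() on a large result or broadcasting a table that is too big.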

OutOfMemoryError: GC overhead limit exceeded indicates the JVM is spending more than 98% of time in garbage collection while recovering less than 2% of heap. Your application is thrashing—too many objects, too little memory.

Python MemoryError is raised inside a Python process. On executors, it typically comes from a Python worker running a pandas UDF on a batch or group too large for its memory. On the driver, it comes from toPandas() materializing a large dataset in the driver’s own Python process.

Exit code 137 (128 + signal 9, SIGKILL) means the process was killed, usually by YARN or Kubernetes for exceeding its container memory limit. The total memory in use (JVM heap + overhead + Python workers) exceeded the container allocation.
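Spark's configuration documentation defines the executor container size as the sum of spark.executor.memory, spark.executor.memoryOverhead, spark.memory.offHeap.size (when off-heap memory is enabled), and spark.executor.pyspark.memory. When overhead is not set, it defaults to 10% of the heap with a 384MB floor (recent releases expose the factor as spark.executor.memoryOverheadFactor; cluster managers may apply different defaults). This hypothetical helper estimates what the cluster manager must grant; if actual usage climbs past it, the container is killed.

```python
# Minimum overhead Spark applies when spark.executor.memoryOverhead is unset
MIN_OVERHEAD_MB = 384


def container_request_mb(executor_memory_mb, overhead_mb=None, overhead_factor=0.10,
                         offheap_mb=0, pyspark_memory_mb=0):
    """Estimate the memory the cluster manager must grant one executor (hypothetical helper)."""
    if overhead_mb is None:
        # Default overhead: a fraction of the heap, but never less than 384 MB
        overhead_mb = max(MIN_OVERHEAD_MB, int(executor_memory_mb * overhead_factor))
    # offheap_mb only counts when spark.memory.offHeap.enabled is true
    return executor_memory_mb + overhead_mb + offheap_mb + pyspark_memory_mb


print(container_request_mb(8192, overhead_mb=2048))  # 10240, the 8g + 2g config shown earlier
print(container_request_mb(8192))                    # 9011 with the default overhead
```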

# How to find these errors in logs
# In Spark UI: Go to Executors tab, check "Logs" column for failed executors
# In YARN: yarn logs -applicationId <app_id>
# In Kubernetes: kubectl logs <pod-name>
#   (executor pods are deleted on exit unless spark.kubernetes.executor.deleteOnTermination=false)

# Programmatically report failed tasks in the stages that are currently running
def check_failed_stages(spark):
    status = spark.sparkContext.statusTracker()
    for stage_id in status.getActiveStageIds():
        stage_info = status.getStageInfo(stage_id)  # None once the stage is no longer tracked
        if stage_info and stage_info.numFailedTasks > 0:
            print(f"Stage {stage_id}: {stage_info.numFailedTasks} failed tasks")

Diagnosing Memory Issues

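The Spark UI is your primary diagnostic tool. The Storage tab lists cached RDDs and DataFrames with their memory and disk footprint. The Executors tab shows storage memory, GC time, and failed tasks per executor; look for executors that consistently run near their limits or spend a large share of task time in GC. In the Stages tab, compare max and median task metrics and check the spill columns: a max far above the median signals skew, and heavy spill means tasks are short on execution memory.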

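Use explain() to see what Spark plans to do with your data. Each Exchange node in the physical plan is a shuffle, and the join operator (BroadcastHashJoin or SortMergeJoin) tells you whether a broadcast will happen: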

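df = spark.read.parquet("s3://bucket/large-dataset/")
result = df.groupBy("user_id").agg({"amount": "sum"})

# Extended mode prints the parsed, analyzed, and optimized logical plans plus the physical plan
result.explain(mode="extended")

# Spark 3.0+ adds "formatted" for a readable operator list...
result.explain(mode="formatted")

# ...and "cost", which adds size estimates where statistics are available
result.explain(mode="cost")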

Check partition sizes to identify imbalances:

from pyspark.sql.functions import avg, count, spark_partition_id
from pyspark.sql.functions import max as spark_max, min as spark_min

def analyze_partitions(df, name="DataFrame"):
    # Row count per partition (empty partitions do not appear here)
    partition_stats = df.withColumn("partition_id", spark_partition_id()) \
        .groupBy("partition_id") \
        .agg(count("*").alias("row_count"))

    # A dict literal keeps only the last "row_count" entry, so use explicit functions
    stats = partition_stats.agg(
        spark_min("row_count").alias("min_rows"),
        spark_max("row_count").alias("max_rows"),
        avg("row_count").alias("avg_rows"),
    ).collect()[0]

    print(f"{name} partition analysis:")
    print(f"  Partitions: {df.rdd.getNumPartitions()}")
    print(f"  Min rows: {stats['min_rows']}, Max rows: {stats['max_rows']}, "
          f"Avg rows: {stats['avg_rows']:.0f}")

    # Skew ratio (max / avg) > 10 indicates problems
    skew_ratio = stats["max_rows"] / max(stats["avg_rows"], 1)
    print(f"  Skew ratio: {skew_ratio:.2f}")
    return partition_stats

analyze_partitions(df, "Raw input")

Configuration-Based Solutions

Configuration tuning addresses symptoms, not causes. That said, proper configuration prevents many issues.

Start with executor memory. A reasonable starting point is to give spark.executor.memory about 75% of the container’s memory and leave 25% for overhead. For PySpark jobs that use pandas UDFs, raise the overhead share to 30-40%.

from pyspark.sql import SparkSession

def create_optimized_session(app_name, executor_memory_gb=8, num_executors=10):
    # Overhead sized at 30% of the executor heap, a starting point for PySpark jobs
    overhead_mb = int(executor_memory_gb * 1024 * 0.3)

    # spark.executor.instances sets the executor count for static allocation.
    # spark.driver.maxResultSize aborts actions such as collect() whose results exceed it.
    # spark.sql.shuffle.partitions=200 is the default; size it to your data.
    return SparkSession.builder \
        .appName(app_name) \
        .config("spark.executor.memory", f"{executor_memory_gb}g") \
        .config("spark.executor.memoryOverhead", f"{overhead_mb}m") \
        .config("spark.executor.cores", "4") \
        .config("spark.executor.instances", str(num_executors)) \
        .config("spark.driver.memory", "4g") \
        .config("spark.driver.maxResultSize", "2g") \
        .config("spark.sql.shuffle.partitions", "200") \
        .config("spark.memory.fraction", "0.6") \
        .config("spark.memory.storageFraction", "0.5") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
        .config("spark.sql.adaptive.skewJoin.enabled", "true") \
        .getOrCreate()
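The 75/25 rule above can be turned into concrete setting values for a given per-executor memory budget. This is a minimal sketch; split_container_mb is a hypothetical helper, and it keeps Spark's 384MB minimum overhead as a floor.

```python
def split_container_mb(container_mb, overhead_share=0.25):
    """Split a per-executor memory budget into heap and overhead (hypothetical helper)."""
    # Keep at least Spark's 384 MB minimum overhead
    overhead_mb = max(384, int(container_mb * overhead_share))
    heap_mb = container_mb - overhead_mb
    return heap_mb, overhead_mb


# 16 GB budget: the 75/25 rule, then a pandas-UDF-heavy job at 35% overhead
for share in (0.25, 0.35):
    heap_mb, overhead_mb = split_container_mb(16384, share)
    print(f"spark.executor.memory={heap_mb}m spark.executor.memoryOverhead={overhead_mb}m")
```

The printed strings map directly onto the two configs; raising the overhead share trades heap for Python worker headroom rather than adding memory.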

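Set spark.sql.shuffle.partitions based on your data size. The default of 200 does not scale with data volume: it produces tiny partitions for small jobs and oversized ones for large shuffles. Aim for partitions between 100MB and 200MB each. For a 100GB shuffle, that’s roughly 500-1,000 partitions. With AQE partition coalescing enabled, Spark can also merge small shuffle partitions after the shuffle runs.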
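The sizing rule is a single division. The sketch below assumes you know the shuffle size (for example, from the Shuffle Write column of the upstream stage in the Spark UI); the helper name and its 128MB default target are illustrative assumptions.

```python
import math

MB = 1024 ** 2
GB = 1024 ** 3


def shuffle_partitions(shuffle_bytes, target_partition_bytes=128 * MB):
    """Partitions needed so each holds about target_partition_bytes (hypothetical helper)."""
    return max(1, math.ceil(shuffle_bytes / target_partition_bytes))


print(shuffle_partitions(100 * GB, target_partition_bytes=200 * MB))  # 512
print(shuffle_partitions(100 * GB, target_partition_bytes=100 * MB))  # 1024
```

Pass the result as a string, e.g. spark.conf.set("spark.sql.shuffle.partitions", str(n)), before the shuffle-producing query runs.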

Code-Level Optimizations

Code changes beat configuration changes. Here’s what actually moves the needle.

Never use collect() on large datasets. Use take(), head(), or toLocalIterator() instead:

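# Bad: pulls the entire dataset into driver memory
all_data = df.collect()

# Better: fetch a bounded sample
sample = df.take(1000)

# Best for large iterations: stream one partition at a time
# (the driver still needs room for the largest partition; process_row is a placeholder)
for row in df.toLocalIterator():
    process_row(row)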
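toLocalIterator() yields individual rows; if the downstream sink prefers batches (a database insert or API call, say), a small generator can group them without holding more than one batch in memory. This is a plain-Python sketch; the name batched is our own, and Python 3.12+ ships a similar itertools.batched that yields tuples.

```python
from itertools import islice


def batched(iterable, batch_size):
    """Yield lists of up to batch_size items from any iterable, lazily."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


print(list(batched(range(5), 2)))  # [[0, 1], [2, 3], [4]]
```

With Spark this would read `for rows in batched(df.toLocalIterator(), 10_000): write_batch(rows)`, where write_batch stands for your own hypothetical sink function.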

Use broadcast joins for small tables. When joining a large table with a small lookup table, broadcast the small one:

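from pyspark.sql.functions import broadcast

large_df = spark.read.parquet("s3://bucket/transactions/")  # 100GB
small_df = spark.read.parquet("s3://bucket/products/")      # 50MB

# Without broadcast: both sides are shuffled for a sort-merge join
# (Spark only auto-broadcasts below spark.sql.autoBroadcastJoinThreshold, 10MB by default)
result_slow = large_df.join(small_df, "product_id")

# With broadcast: small_df is copied to every executor and large_df is not shuffled
# (the broadcast table must fit in driver and executor memory)
result_fast = large_df.join(broadcast(small_df), "product_id")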
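Why the large side stays put: in a broadcast hash join, the small table is turned into a hash map that every executor holds, and each large-side partition probes that map locally. The toy model below illustrates the idea with lists of dicts standing in for partitions; it is a conceptual sketch, not Spark's implementation.

```python
def broadcast_hash_join(large_partitions, small_rows, key):
    """Toy inner join: hash the small side once, probe it from each large partition."""
    # Build side: the whole small table as a hash map (what Spark ships to each executor)
    lookup = {}
    for row in small_rows:
        lookup.setdefault(row[key], []).append(row)

    # Probe side: each large partition is joined locally; no large-side rows move
    output = []
    for partition in large_partitions:
        for row in partition:
            for match in lookup.get(row[key], []):
                output.append({**row, **match})
    return output


transactions = [  # two partitions of the large side
    [{"product_id": 1, "amount": 10}, {"product_id": 2, "amount": 5}],
    [{"product_id": 1, "amount": 7}, {"product_id": 3, "amount": 1}],
]
products = [{"product_id": 1, "name": "pen"}, {"product_id": 2, "name": "ink"}]
print(broadcast_hash_join(transactions, products, "product_id"))
```

The same model shows the cost: the lookup map is duplicated on every executor, so a "small" table that is actually large multiplies memory use across the cluster.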

Repartition strategically. Repartition before expensive operations, coalesce after filters:

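# After filtering, reduce partitions without a full shuffle
filtered = df.filter(df.status == "active")  # 90% filtered out
filtered = filtered.coalesce(50)  # e.g. from 200 to 50 partitions

# Before a join or aggregation on join_key, repartition by that key so matching rows are co-located.
# This balances many moderately sized keys, but every row of a single hot key still lands in one partition.
df_repartitioned = df.repartition(200, "join_key")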
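A quick way to see why repartitioning by key cannot split a hot key: hash partitioning sends equal keys to the same partition no matter how many partitions you ask for. The toy model below uses zlib.crc32 as a deterministic stand-in for Spark's Murmur3 hash, so the actual partition numbers differ from Spark's; the behavior it illustrates is the same.

```python
import zlib
from collections import Counter


def partition_sizes(keys, num_partitions):
    """Rows per partition under hash partitioning by key (toy model)."""
    # Spark hashes with Murmur3; crc32 is a deterministic stand-in for illustration
    return Counter(zlib.crc32(str(key).encode()) % num_partitions for key in keys)


keys = ["hot"] * 1000 + [f"key_{i}" for i in range(200)]
hot_partition = zlib.crc32(b"hot") % 8
sizes = partition_sizes(keys, 8)
print(hot_partition, sizes[hot_partition])  # the hot partition holds at least 1000 rows
```

Raising the partition count spreads the 200 ordinary keys more thinly, but the 1,000 "hot" rows always share one partition.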

Handling Data Skew

Data skew kills PySpark jobs. When one partition holds 100x more data than the others, the single task processing it runs long after the rest have finished, other executors sit idle, and that task is often the one that runs out of memory.

Identify skewed keys first:

from pyspark.sql.functions import count, col

def find_skewed_keys(df, key_column, threshold_ratio=10):
    # Row count per key
    key_counts = df.groupBy(key_column).agg(count("*").alias("cnt"))
    avg_count = key_counts.agg({"cnt": "avg"}).collect()[0][0]

    # Keys holding more than threshold_ratio times the average row count
    skewed = key_counts.filter(col("cnt") > avg_count * threshold_ratio)
    return skewed.orderBy(col("cnt").desc())

skewed_keys = find_skewed_keys(df, "customer_id")
skewed_keys.show(10)
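The mean-based test in find_skewed_keys has a blind spot worth knowing: the hot key's own rows inflate the mean. Since a key's count can never exceed the total, which is the number of distinct keys times the mean, a key can only exceed threshold_ratio times the mean when there are more than threshold_ratio distinct keys. The pure-Python mirror below (a hypothetical helper, same rule) shows both cases.

```python
from collections import Counter
from statistics import mean


def skewed_keys_local(keys, threshold_ratio=10):
    """Pure-Python mirror of find_skewed_keys: keys above threshold_ratio x mean count."""
    counts = Counter(keys)
    avg = mean(list(counts.values()))
    return {key: cnt for key, cnt in counts.items() if cnt > avg * threshold_ratio}


many_keys = ["hot"] * 1000 + [f"k{i}" for i in range(99)]
few_keys = ["hot"] * 1000 + ["a", "b"]
print(skewed_keys_local(many_keys))  # {'hot': 1000}
print(skewed_keys_local(few_keys))   # {} because the hot key inflates the mean
```

For columns with few distinct values, comparing each count against the median, or inspecting the top counts directly, avoids this effect.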

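Salting appends a random suffix (the salt) to the join key on the large side and replicates each small-side row once per salt value, so a hot key’s rows spread across up to num_salts partitions while every row still finds its match: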

from pyspark.sql.functions import array, col, concat, explode, floor, lit, rand

def salted_join(large_df, small_df, join_key, num_salts=10):
    # Large side: tag each row with a random salt in [0, num_salts)
    large_salted = large_df.withColumn(
        "salt", floor(rand() * num_salts).cast("int")
    ).withColumn(
        "salted_key", concat(col(join_key), lit("_"), col("salt"))
    ).drop("salt")

    # Small side: one copy of every row per salt value (its size grows num_salts times)
    salt_array = array([lit(i) for i in range(num_salts)])
    small_exploded = small_df.withColumn("salt", explode(salt_array)) \
        .withColumn("salted_key", concat(col(join_key), lit("_"), col("salt"))) \
        .drop("salt", join_key)  # avoid duplicate columns after the join

    # Join on the salted key; join_key comes from the large side
    result = large_salted.join(small_exploded, "salted_key") \
        .drop("salted_key")

    return result
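To convince yourself that salting changes the load distribution but not the join result, here is a pure-Python model of salted_join. It is an illustrative sketch: it assigns salts round-robin so the output is deterministic, whereas the Spark version above uses rand().

```python
from collections import Counter


def salted_join_local(large_rows, small_rows, key, num_salts=4):
    """Toy model of a salted inner join; returns joined rows and rows per salted key."""
    # Large side: spread each key over num_salts sub-keys (round-robin here, rand() in Spark)
    large_salted = [((row[key], i % num_salts), row) for i, row in enumerate(large_rows)]

    # Small side: one copy of each row per salt value, so every sub-key finds its match
    small_index = {}
    for row in small_rows:
        for salt in range(num_salts):
            small_index.setdefault((row[key], salt), []).append(row)

    joined = [{**row, **match}
              for salted_key, row in large_salted
              for match in small_index.get(salted_key, [])]
    load = Counter(salted_key for salted_key, _ in large_salted)
    return joined, load


large = [{"customer_id": "hot", "amount": i} for i in range(100)] + [{"customer_id": "cold", "amount": -1}]
small = [{"customer_id": "hot", "tier": "gold"}, {"customer_id": "cold", "tier": "basic"}]
joined, load = salted_join_local(large, small, "customer_id")
print(len(joined), load[("hot", 0)])  # 101 25
```

The 100 "hot" rows now form four groups of 25 that can land in different partitions, and the joined output has the same 101 rows a plain join would produce.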

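In Spark 3.x, Adaptive Query Execution (AQE) can detect skewed partitions in sort-merge joins and split them automatically. AQE and its skew-join handling are on by default since Spark 3.2, and the factor and threshold below are the default values, so lower them to catch milder skew. AQE’s skew handling targets joins; a skewed groupBy still needs salting or a similar fix: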

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
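The Spark SQL performance tuning guide describes the detection rule: a partition counts as skewed only if it is larger than skewedPartitionFactor times the median partition size and also larger than skewedPartitionThresholdInBytes. The simplified model below applies that rule to a list of partition sizes. The helper name is hypothetical, and Spark evaluates each side of the join separately.

```python
from statistics import median

MB = 1024 ** 2


def aqe_skewed_partitions(partition_sizes, factor=5.0, threshold_bytes=256 * MB):
    """Indices of partitions matching the AQE skew-join test (simplified model)."""
    med = median(partition_sizes)
    # A partition must exceed BOTH factor x median and the absolute byte threshold
    return [i for i, size in enumerate(partition_sizes)
            if size > factor * med and size > threshold_bytes]


print(aqe_skewed_partitions([100 * MB] * 9 + [2048 * MB]))  # [9]
print(aqe_skewed_partitions([100 * MB] * 9 + [400 * MB]))   # [] (below 5x median)
print(aqe_skewed_partitions([10 * MB] * 9 + [200 * MB]))    # [] (below 256MB)
```

The last case is why small jobs rarely trigger skew handling: a partition 20x the median is still ignored if it stays under the absolute threshold.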

Prevention and Monitoring Best Practices

Build memory awareness into your development workflow.

Create a monitoring wrapper for production jobs:

import logging
import time
from functools import wraps

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("spark_memory")

def monitor_memory(description):
    """Label the wrapped step's jobs in the Spark UI and log its duration or failure."""
    def decorator(func):
        @wraps(func)
        def wrapper(spark, *args, **kwargs):
            # Jobs triggered inside func show this description in the Spark UI,
            # where their memory and GC metrics can be inspected
            spark.sparkContext.setJobDescription(description)

            start = time.time()
            try:
                result = func(spark, *args, **kwargs)
                duration = time.time() - start
                logger.info(f"{description} completed in {duration:.2f}s")
                return result
            except Exception as e:
                duration = time.time() - start
                logger.error(f"{description} failed after {duration:.2f}s: {e}")
                raise
        return wrapper
    return decorator

@monitor_memory("Load and transform customer data")
def process_customers(spark):
    df = spark.read.parquet("s3://bucket/customers/")
    # Transformations are lazy: end with an action so the timing covers the real work
    return df.groupBy("region").count().collect()
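The wrapper above logs timing but not memory. One way to add memory data is Spark's monitoring REST API, whose /api/v1/applications/<app-id>/executors endpoint returns per-executor summaries including memoryUsed and maxMemory (storage memory, in bytes) and totalGCTime and totalDuration (milliseconds). The sketch below assumes the driver UI is reachable from where the code runs. The helper names and the 10% GC warning threshold are our own assumptions, loosely modeled on how the Executors tab highlights high GC time.

```python
import json
from urllib.request import urlopen


def fetch_executor_summaries(spark):
    """Read per-executor metrics from the driver's monitoring REST API."""
    sc = spark.sparkContext
    # Requires the Spark UI to be enabled and reachable from this process
    url = f"{sc.uiWebUrl}/api/v1/applications/{sc.applicationId}/executors"
    with urlopen(url) as response:
        return json.load(response)


def summarize_executors(executors, gc_warn_ratio=0.1):
    """(id, storage memory used fraction, GC time fraction, GC warning) per executor."""
    report = []
    for ex in executors:
        # memoryUsed/maxMemory describe storage memory, in bytes
        mem_frac = ex["memoryUsed"] / ex["maxMemory"] if ex["maxMemory"] else 0.0
        # totalGCTime and totalDuration are both in milliseconds
        gc_frac = ex["totalGCTime"] / ex["totalDuration"] if ex["totalDuration"] else 0.0
        report.append((ex["id"], round(mem_frac, 2), round(gc_frac, 2), gc_frac > gc_warn_ratio))
    return report
```

Inside the wrapper, after func returns, you could log each tuple from summarize_executors(fetch_executor_summaries(spark)) to record which executors ran hot during that step.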

Code review checklist for memory-safe PySpark:

  1. No collect() without size limits
  2. Broadcast joins for tables under 100MB
  3. Explicit partition counts after filters
  4. No row-at-a-time Python UDFs on large datasets (prefer built-in functions, then pandas UDFs with Arrow)
  5. Caching only when data is reused multiple times
  6. unpersist() called after cached data is no longer needed

Memory errors in PySpark are frustrating but predictable. Understand the memory model, diagnose with the right tools, and fix the root cause rather than throwing more memory at the problem. Your cluster costs will thank you.
