Apache Spark - When to Cache vs Persist vs Checkpoint

Key Insights

  • Cache and persist store data in memory or disk but preserve lineage, meaning Spark can recompute lost partitions; checkpoint writes to reliable storage and truncates lineage completely, trading storage space for guaranteed recovery.
  • Use persist(MEMORY_AND_DISK_SER) as your default for production workloads—it handles memory pressure gracefully and serialized storage reduces GC overhead significantly.
  • Always checkpoint after caching in iterative algorithms; checkpointing an uncached RDD forces a full recomputation, doubling your work instead of saving it.

The Cost of Recomputation

Spark’s lazy evaluation is both its greatest strength and a subtle performance trap. When you chain transformations, Spark builds a Directed Acyclic Graph (DAG) representing the lineage of your data. Nothing executes until you call an action like count() or collect(). This design enables powerful optimizations, but it also means that every action triggers a full recomputation from the source.

Consider an iterative machine learning algorithm that runs 100 iterations. Without intermediate storage, Spark would read your source data and apply all transformations 100 times. For a 10-stage pipeline reading from S3, that’s 1,000 stages of redundant work.

Spark provides three mechanisms to break this pattern: cache(), persist(), and checkpoint(). Each serves a distinct purpose, and choosing wrong can cost you hours of compute time—or crash your cluster entirely.

Cache: The Simple In-Memory Default

The cache() method is syntactic sugar for persist() with the default storage level: StorageLevel.MEMORY_ONLY for RDDs, and MEMORY_AND_DISK for DataFrames. An RDD's cached partitions live in executor memory as deserialized Java objects; DataFrames are cached in Spark's compressed columnar format.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("CacheExample").getOrCreate()

# Load and filter a large dataset
transactions = spark.read.parquet("s3://data-lake/transactions/")
high_value = transactions.filter(transactions.amount > 10000)

# Cache the filtered result for reuse
high_value.cache()

# First action materializes the cache
total_count = high_value.count()

# Subsequent actions read from memory
by_region = high_value.groupBy("region").sum("amount")
by_category = high_value.groupBy("category").avg("amount")

by_region.show()
by_category.show()

This pattern works well when your filtered dataset fits comfortably in aggregate executor memory. The key word is “comfortably”—Spark needs memory for execution too, not just storage.

When memory runs short with MEMORY_ONLY, Spark doesn’t spill to disk. It evicts cached partitions using LRU and recomputes them on demand. This behavior is fine for short lineages but catastrophic for complex pipelines. Your job won’t fail; it will just slow down mysteriously as partitions get evicted and recomputed repeatedly.

Use cache() when:

  • Your data fits in 60% or less of available executor memory
  • The lineage is short (under 10 stages)
  • You’re in development and want simple semantics

Persist: Flexible Storage Levels

The persist() method accepts a StorageLevel argument that controls exactly where and how Spark stores your data.

from pyspark import StorageLevel
from pyspark.sql.functions import collect_list

# user_profiles, product_catalog, and calculate_value_udf are assumed
# to be defined elsewhere in the application

# Large dataset that won't fit entirely in memory
user_sessions = spark.read.parquet("s3://logs/sessions/")
enriched = (user_sessions
    .join(user_profiles, "user_id")
    .join(product_catalog, "product_id")
    .withColumn("session_value", calculate_value_udf("events")))

# Serialize and allow disk spillover
enriched.persist(StorageLevel.MEMORY_AND_DISK_SER)

# Materialize
enriched.count()

# Multiple downstream consumers
enriched.write.parquet("s3://output/enriched_sessions/")
enriched.groupBy("user_segment").agg(collect_list("events")).show()

Here’s what each storage level actually does:

Level                 Memory  Disk  Serialized (in memory)  Replicated
MEMORY_ONLY           Yes     No    No                      No
MEMORY_AND_DISK       Yes     Yes   No                      No
MEMORY_ONLY_SER       Yes     No    Yes                     No
MEMORY_AND_DISK_SER   Yes     Yes   Yes                     No
DISK_ONLY             No      Yes   Yes                     No
*_2 variants          same as the base level                Yes (2 copies)

(Data spilled to disk is always stored serialized; the Serialized column describes the in-memory format.)

Serialized storage (_SER variants) trades CPU for memory efficiency. Deserialized Java objects carry significant overhead—often 2-4x the raw data size. Serialization with Kryo compresses this substantially but requires CPU cycles to serialize and deserialize.

For production workloads, MEMORY_AND_DISK_SER is almost always the right choice. It handles memory pressure gracefully, reduces GC overhead, and won’t silently recompute partitions.

// Scala example with explicit storage level
import org.apache.spark.storage.StorageLevel

val features = rawData
  .transform(extractFeatures)
  .transform(normalizeFeatures)
  .persist(StorageLevel.MEMORY_AND_DISK_SER)

// Force materialization before iterative training
features.count()

// Training loop reuses cached features
// (initialWeights, computeGradient, updateWeights assumed defined elsewhere)
var weights = initialWeights  // must be a var: reassigned every iteration
(1 to 100).foreach { _ =>
  val gradients = features.map(computeGradient(weights, _))
  weights = updateWeights(weights, gradients.reduce(_ + _))
}

Checkpoint: Breaking the Lineage

Checkpointing is fundamentally different from caching. While cache and persist store data but preserve lineage, checkpoint writes data to reliable storage (typically HDFS or S3) and truncates the lineage graph entirely.

# Configure checkpoint directory (must be reliable distributed storage)
spark.sparkContext.setCheckpointDir("hdfs://cluster/spark-checkpoints/")

def train_model(data, iterations):
    current = data
    
    for i in range(iterations):
        # Compute gradient update
        current = current.map(lambda x: update_weights(x))
        
        # Checkpoint every 10 iterations to truncate lineage
        # (RDD.checkpoint() marks in place; the next action writes it)
        if i % 10 == 0:
            current.cache()  # Cache first to avoid recomputing for the write
            current.checkpoint()
            current.count()  # Materializes the cache and the checkpoint
    
    return current

Why does lineage truncation matter? In iterative algorithms, lineage grows with each iteration. After 100 iterations of PageRank, your lineage graph has 100+ stages. If a partition fails, Spark must recompute from the beginning. Worse, the lineage metadata itself consumes driver memory and can cause OOM errors.

Checkpointing solves this by writing data to disk and replacing the lineage with “read from checkpoint file.” Recovery now means reading one file, not recomputing 100 transformations.

Spark offers two checkpoint modes:

Reliable checkpointing writes to distributed storage like HDFS. It survives executor failures, driver restarts, and cluster termination. Use this for production jobs.

Local checkpointing (localCheckpoint()) writes to executor local disk. It’s faster but doesn’t survive executor failures. Use this only when you need lineage truncation for memory reasons but can tolerate recomputation on failure.

# Local checkpoint for lineage truncation without reliability guarantees.
# Like checkpoint(), DataFrame.localCheckpoint() returns a new DataFrame.
streaming_aggregates = streaming_aggregates.localCheckpoint()

Decision Framework: Choosing the Right Strategy

Here’s a concrete example showing all three approaches in a PageRank implementation:

from pyspark.sql import functions as F

def pagerank_naive(edges, iterations):
    """No caching - recomputes everything each iteration"""
    ranks = edges.groupBy("src").count().withColumn("rank", F.lit(1.0))
    
    for _ in range(iterations):
        contributions = edges.join(ranks, "src")
        ranks = (contributions.groupBy("dst")
                 .agg((F.sum("rank") * 0.85 + 0.15).alias("rank"))
                 .withColumnRenamed("dst", "src"))
    
    return ranks  # Lineage has iterations * stages

def pagerank_cached(edges, iterations):
    """Cached edges - still has lineage growth"""
    edges.cache()
    edges.count()
    
    ranks = edges.groupBy("src").count().withColumn("rank", F.lit(1.0))
    ranks.persist(StorageLevel.MEMORY_AND_DISK_SER)
    
    for _ in range(iterations):
        old_ranks = ranks
        contributions = edges.join(ranks, "src")
        ranks = (contributions.groupBy("dst")
                 .agg((F.sum("rank") * 0.85 + 0.15).alias("rank"))
                 .withColumnRenamed("dst", "src"))
        ranks.persist(StorageLevel.MEMORY_AND_DISK_SER)
        ranks.count()
        old_ranks.unpersist()
    
    return ranks

def pagerank_checkpointed(edges, iterations, checkpoint_interval=10):
    """Checkpointed - bounded lineage and failure recovery"""
    spark.sparkContext.setCheckpointDir("hdfs://cluster/checkpoints/")
    
    edges.cache()
    edges.count()
    
    ranks = edges.groupBy("src").count().withColumn("rank", F.lit(1.0))
    
    for i in range(iterations):
        contributions = edges.join(ranks, "src")
        ranks = (contributions.groupBy("dst")
                 .agg((F.sum("rank") * 0.85 + 0.15).alias("rank"))
                 .withColumnRenamed("dst", "src"))
        
        if i % checkpoint_interval == 0:
            ranks.cache()
            # DataFrame.checkpoint() returns a new DataFrame - keep the reference
            ranks = ranks.checkpoint()
    
    return ranks

Use this decision tree:

  1. Will the data be reused? No → Don’t cache at all
  2. Does it fit in memory? Yes → cache() or persist(MEMORY_ONLY_SER)
  3. Is memory tight? Yes → persist(MEMORY_AND_DISK_SER)
  4. Is lineage growing unbounded? Yes → Add checkpointing
  5. Must survive driver restart? Yes → Reliable checkpoint to HDFS/S3
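For reference, the tree above maps mechanically to code; this helper is purely illustrative (the function and its arguments are assumptions, not Spark API):

```python
def storage_strategy(reused, fits_in_memory, lineage_unbounded, survives_restart):
    """Walk the decision tree above and return the recommended call."""
    if not reused:
        return "no caching"
    level = ("cache() or persist(MEMORY_ONLY_SER)" if fits_in_memory
             else "persist(MEMORY_AND_DISK_SER)")
    if survives_restart:
        return level + " + reliable checkpoint to HDFS/S3"
    if lineage_unbounded:
        return level + " + checkpoint"
    return level

print(storage_strategy(True, False, True, False))
# persist(MEMORY_AND_DISK_SER) + checkpoint
```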

Common Pitfalls and Best Practices

Forgetting to unpersist: Cached data stays in memory until the SparkContext ends or you explicitly remove it. In long-running applications, this causes memory leaks.

# Bad: Memory leak in a loop
for date in date_range:
    daily_data = load_data(date)
    daily_data.cache()
    process(daily_data)
    # daily_data stays cached forever

# Good: Explicit cleanup
for date in date_range:
    daily_data = load_data(date)
    daily_data.cache()
    daily_data.count()
    process(daily_data)
    daily_data.unpersist()

Checkpointing without caching first: This is the most common mistake. Checkpointing triggers a job that recomputes the entire lineage just to write the files; with the data cached, the write reads from the cache instead. Without caching, you've done the work twice. Note also that on DataFrames, checkpoint() returns a new DataFrame with the truncated lineage; keep the returned reference, or you keep the old lineage.

# Bad: Double computation
expensive_result = long_pipeline(data)
checkpointed = expensive_result.checkpoint()  # Runs the pipeline once for the job,
                                              # again for the checkpoint write
checkpointed.count()                          # Reads from the checkpoint

# Good: Cache before checkpoint
expensive_result = long_pipeline(data)
expensive_result.cache()
checkpointed = expensive_result.checkpoint()  # Computes once; the write reads the cache
checkpointed.count()                          # Reads from the checkpoint

Over-caching: Not everything needs caching. If you use a DataFrame exactly once, caching adds overhead with no benefit. Profile first, cache second.

Conclusion

Monitor your caching strategy through the Spark UI’s Storage tab. It shows exactly what’s cached, how much memory it consumes, and what fraction is actually in memory versus spilled to disk.

The right choice depends on your specific workload: cache for simplicity, persist for flexibility, and checkpoint for reliability. Start with persist(MEMORY_AND_DISK_SER) for most production use cases, add checkpointing for iterative algorithms, and always clean up after yourself with unpersist().
