Spark Scala - Cache and Persist
Key Insights
- cache() is shorthand for persist() with the default storage level (MEMORY_ONLY for RDDs, MEMORY_AND_DISK for DataFrames and Datasets); use persist() when you need control over the storage level, especially for large datasets that won’t fit in memory.
- Caching is lazy and does nothing until you trigger an action; forgetting this is the most common mistake developers make when debugging performance issues.
- Strategic caching of datasets used in iterative algorithms or multiple downstream actions can yield 10x+ performance improvements, but indiscriminate caching wastes cluster resources and can cause memory pressure.
Introduction to Data Caching in Spark
Spark’s lazy evaluation model means transformations build up a lineage graph that gets executed only when you call an action. This is elegant for optimization, but it has a cost: every action recomputes the entire lineage from scratch.
Consider a scenario where you load a large dataset, apply several transformations, then run multiple aggregations. Without caching, Spark re-reads the source data and re-applies every transformation for each aggregation. For iterative machine learning algorithms that might run hundreds of iterations on the same dataset, this becomes catastrophically expensive.
Caching breaks this pattern by storing intermediate results in memory (or on disk) so subsequent actions can reuse them. The performance difference can be dramatic: I’ve seen jobs go from hours to minutes simply by adding a single cache() call in the right place.
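The recomputation cost is easy to observe on a toy job. The sketch below is not from any real pipeline: it assumes a local SparkSession (`local[2]`), uses a hypothetical object name `RecomputeDemo`, and counts how often a map function runs with a `LongAccumulator`. Accumulator updates made inside transformations can be applied more than once if tasks are retried, so treat the counts as illustrative of a clean local run.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical demo object; all names here are illustrative.
object RecomputeDemo {
  /** Returns (map calls for two actions without cache, map calls for two actions with cache). */
  def run(): (Long, Long) = {
    val spark = SparkSession.builder()
      .master("local[2]") // assumption: local mode, for demonstration only
      .appName("RecomputeDemo")
      .getOrCreate()
    val sc = spark.sparkContext

    val calls = sc.longAccumulator("mapCalls")
    val doubled = sc.parallelize(1 to 1000, 4).map { x =>
      calls.add(1) // count every invocation of the map function
      x * 2
    }

    // Two actions without caching: the map function runs again for each action.
    doubled.count()
    doubled.count()
    val withoutCache = calls.sum

    calls.reset()
    doubled.cache()
    doubled.count() // computes the partitions and stores them
    doubled.count() // served from the stored partitions
    val withCache = calls.sum

    doubled.unpersist(blocking = true)
    (withoutCache, withCache)
  }

  def main(args: Array[String]): Unit = println(run())
}
```

On a clean local run I would expect the first number to be twice the second, since the uncached lineage is evaluated once per action.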
Cache vs Persist: Understanding the Difference
Let’s clear up the confusion immediately: cache() and persist() are not different operations. The cache() method is shorthand for persist() with the default storage level, which is MEMORY_ONLY for RDDs and MEMORY_AND_DISK for DataFrames and Datasets.
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
val spark = SparkSession.builder()
  .appName("CacheDemo")
  .getOrCreate()
import spark.implicits._
val rawData = spark.read.parquet("/data/transactions")
// categorizeUDF is assumed to be a UDF defined elsewhere in your code
val processedData = rawData
  .filter($"amount" > 100)
  .withColumn("category", categorizeUDF($"merchant"))
// For a DataFrame these two calls request the same storage level.
// Use one or the other: persisting an already cached plan is ignored with a warning.
val cachedData = processedData.cache()
val persistedData = processedData.persist(StorageLevel.MEMORY_AND_DISK)
Use cache() when the default storage level suits your dataset and you want clean, readable code. Use persist() when you need a different storage level or want to be explicit about your caching strategy.
My recommendation: default to persist() with an explicit storage level. It makes your intentions clear and forces you to think about memory constraints.
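One detail worth verifying on your own Spark version: as far as I can tell from the API docs, the level that cache() picks depends on the API. RDD.cache() uses MEMORY_ONLY, while Dataset.cache() (and therefore DataFrame.cache()) uses MEMORY_AND_DISK. A minimal check, assuming a local session and a hypothetical object name `DefaultLevels`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

// Hypothetical helper; the object and method names are illustrative.
object DefaultLevels {
  /** Returns the storage levels assigned by cache() on an RDD and on a DataFrame. */
  def run(): (StorageLevel, StorageLevel) = {
    val spark = SparkSession.builder()
      .master("local[2]") // assumption: local mode, for demonstration only
      .appName("DefaultLevels")
      .getOrCreate()

    val rdd = spark.sparkContext.parallelize(1 to 10).cache()
    val df = spark.range(10).toDF("id").cache()

    // Both levels are set as soon as cache() is called, before any action runs.
    val levels = (rdd.getStorageLevel, df.storageLevel)

    rdd.unpersist(blocking = true)
    df.unpersist(blocking = true)
    levels
  }

  def main(args: Array[String]): Unit = {
    val (rddLevel, dfLevel) = run()
    println(s"RDD.cache():       $rddLevel")
    println(s"DataFrame.cache(): $dfLevel")
  }
}
```

This is one more reason to prefer persist() with an explicit level: the reader does not have to remember which default applies.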
Storage Levels Explained
Spark provides several storage levels, each with different trade-offs between memory usage, CPU overhead, and fault tolerance.
import org.apache.spark.storage.StorageLevel
// Memory only - fastest; partitions that don't fit are recomputed when needed
df.persist(StorageLevel.MEMORY_ONLY)
// Spill to disk if memory is full - safer for large datasets
df.persist(StorageLevel.MEMORY_AND_DISK)
// Disk only - slowest, but handles arbitrarily large data
df.persist(StorageLevel.DISK_ONLY)
// Serialized variants - less memory, more CPU
df.persist(StorageLevel.MEMORY_ONLY_SER)
df.persist(StorageLevel.MEMORY_AND_DISK_SER)
// Replicated variants - fault tolerant, 2x storage
df.persist(StorageLevel.MEMORY_ONLY_2)
df.persist(StorageLevel.MEMORY_AND_DISK_2)
Here’s how to think about each level:
MEMORY_ONLY: Default for cache() on RDDs. Stores deserialized Java objects in JVM heap. Fastest access but highest memory footprint. Partitions that don’t fit are recomputed on demand.
MEMORY_AND_DISK: Default for cache() on DataFrames and Datasets, and my go-to for production workloads. Spills partitions to disk when memory is exhausted. Slightly slower than pure memory but far more reliable.
DISK_ONLY: Rarely useful. If your data doesn’t fit in memory at all, you’re probably better off optimizing your transformations or increasing cluster resources.
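MEMORY_ONLY_SER / MEMORY_AND_DISK_SER: Store serialized bytes instead of objects. This uses 2-5x less memory but requires deserialization on every access. For RDD caches, use Kryo serialization for best results (DataFrames and Datasets are cached in Spark’s own columnar format, so Kryo matters less there). The serializer is read at startup, so set it when building the session or via spark-submit --conf, not with spark.conf.set on a running session:
val spark = SparkSession.builder()
  .appName("CacheDemo")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.kryo.registrationRequired", "false")
  .getOrCreate()
val largeDataset = spark.read.parquet("/data/huge_table")
  .persist(StorageLevel.MEMORY_AND_DISK_SER)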
Replicated variants (_2 suffix): Store two copies across different nodes. Use these when recomputation is extremely expensive and you need fault tolerance within a single job. Rarely necessary for most workloads.
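Each named level is just a combination of a few flags, and reading them directly can make the trade-offs above more concrete. A small sketch, using the public fields on `StorageLevel`; the object name `StorageLevelFlags` and the `describe` helper are hypothetical:

```scala
import org.apache.spark.storage.StorageLevel

// Hypothetical helper; only the StorageLevel fields are part of Spark's API.
object StorageLevelFlags {
  /** One-line summary of the flags encoded in a storage level. */
  def describe(level: StorageLevel): String =
    s"memory=${level.useMemory} disk=${level.useDisk} " +
      s"deserialized=${level.deserialized} replicas=${level.replication}"

  def main(args: Array[String]): Unit = {
    Seq(
      "MEMORY_ONLY" -> StorageLevel.MEMORY_ONLY,
      "MEMORY_AND_DISK" -> StorageLevel.MEMORY_AND_DISK,
      "DISK_ONLY" -> StorageLevel.DISK_ONLY,
      "MEMORY_AND_DISK_SER" -> StorageLevel.MEMORY_AND_DISK_SER,
      "MEMORY_ONLY_2" -> StorageLevel.MEMORY_ONLY_2
    ).foreach { case (name, level) => println(s"$name: ${describe(level)}") }
  }
}
```

The _SER variants differ from their counterparts only in the deserialized flag, and the _2 variants only in the replica count.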
When to Cache: Best Practices
Not everything should be cached. Here are the scenarios where caching provides real value:
1. Multiple actions on the same dataset:
import org.apache.spark.sql.functions.sum
val aggregatedSales = salesData
  .groupBy("region", "product_category")
  .agg(sum("revenue").as("total_revenue"))
  .persist(StorageLevel.MEMORY_AND_DISK)
// First action - triggers computation and caching
val topRegions = aggregatedSales
.groupBy("region")
.agg(sum("total_revenue").as("region_total"))
.orderBy($"region_total".desc)
.limit(10)
.collect()
// Second action - reads from cache
val topCategories = aggregatedSales
.groupBy("product_category")
.agg(sum("total_revenue").as("category_total"))
.orderBy($"category_total".desc)
.limit(10)
.collect()
2. Iterative algorithms (ML training, graph processing):
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.LogisticRegression
val trainingData = spark.read.parquet("/data/features")
.filter($"label".isNotNull)
.persist(StorageLevel.MEMORY_AND_DISK)
// Force caching before iteration
trainingData.count()
val featureCols = Array("feature1", "feature2", "feature3")
val assembler = new VectorAssembler()
.setInputCols(featureCols)
.setOutputCol("features")
val preparedData = assembler.transform(trainingData)
.select("features", "label")
.persist(StorageLevel.MEMORY_AND_DISK)
// Multiple model iterations reuse cached data
val regParams = Seq(0.01, 0.1, 1.0)
val models = regParams.map { regParam =>
new LogisticRegression()
.setRegParam(regParam)
.setMaxIter(100)
.fit(preparedData)
}
3. Expensive transformations upstream of a branch point:
val expensiveResult = rawData
.join(lookupTable, "key") // Expensive shuffle
.withColumn("derived", complexUDF($"value")) // Expensive UDF
.persist(StorageLevel.MEMORY_AND_DISK)
// Branch 1: Analytics pipeline
val analyticsOutput = expensiveResult.groupBy("category").count()
// Branch 2: Export pipeline
val exportOutput = expensiveResult.select("id", "derived", "timestamp")
Anti-patterns to avoid:
- Caching immediately after reading data (cache after transformations, not before)
- Caching small datasets that recompute in milliseconds
- Caching datasets used only once
- Caching before shuffle operations (the shuffle output is often what you want to cache)
Monitoring and Managing Cached Data
The Spark UI’s Storage tab shows all cached RDDs and DataFrames. Check it regularly to understand your memory usage.
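// Check the storage level requested for a DataFrame
println(s"Storage level: ${df.storageLevel}")
// Example output after df.cache(): Storage level: StorageLevel(disk, memory, deserialized, 1 replicas)
// A level other than NONE means the DataFrame is marked for caching.
// It does not tell you whether the data has been materialized yet; use the Storage tab for that.
println(s"Marked for caching: ${df.storageLevel != StorageLevel.NONE}")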
When you’re done with cached data, release it explicitly:
// Non-blocking unpersist (the default) - returns immediately, blocks are removed asynchronously
df.unpersist()
// Eager unpersist - blocks until memory is freed
df.unpersist(blocking = true)
Always unpersist when you’re done, especially in long-running applications or notebooks. Memory leaks from forgotten cached datasets are a common cause of job failures.
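One way to make the cleanup hard to forget is a small loan-pattern helper that persists a dataset, runs a block, and always unpersists afterwards. This is a sketch of my own, not a Spark API: `CacheScope` and `withPersisted` are hypothetical names, and it assumes the dataset was not already persisted by the caller (otherwise the helper would drop that cache too).

```scala
import org.apache.spark.sql.Dataset
import org.apache.spark.storage.StorageLevel

// Hypothetical helper; only persist/unpersist are Spark API calls.
object CacheScope {
  /**
   * Persists `ds`, runs `body` against it, then unpersists even if `body` throws.
   * `body` should run its actions before returning: a lazy Dataset returned
   * from the block would be evaluated after the cache is already gone.
   */
  def withPersisted[T, R](
      ds: Dataset[T],
      level: StorageLevel = StorageLevel.MEMORY_AND_DISK
  )(body: Dataset[T] => R): R = {
    val persisted = ds.persist(level)
    try body(persisted)
    finally persisted.unpersist(blocking = true)
  }
}

// Usage sketch (assumes an existing Dataset named `events`):
//   val (total, active) = CacheScope.withPersisted(events) { cached =>
//     (cached.count(), cached.filter("status = 'active'").count())
//   }
```

The blocking unpersist is a deliberate choice here so that memory is free by the time the helper returns; switch to the non-blocking form if that wait matters more than predictability.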
Common Pitfalls and Troubleshooting
Pitfall 1: Forgetting that caching is lazy
This is the most common mistake:
val df = spark.read.parquet("/data/large_table")
.filter($"status" === "active")
.cache()
// WRONG: Timing this won't tell you anything useful
val startTime = System.currentTimeMillis()
val result = df.groupBy("category").count().collect()
val duration = System.currentTimeMillis() - startTime
// This includes both caching AND the aggregation
// RIGHT: Trigger caching explicitly first
df.count() // Forces materialization into cache
val startTime2 = System.currentTimeMillis()
val result2 = df.groupBy("region").count().collect()
val duration2 = System.currentTimeMillis() - startTime2
// This measures only the aggregation on cached data
Pitfall 2: Caching before shuffle boundaries
// Inefficient - caching before the shuffle
val cachedRaw = rawData.cache()
val resultFromRaw = cachedRaw.groupBy("key").agg(sum("value"))
// Better - cache after the shuffle
val cachedResult = rawData.groupBy("key").agg(sum("value")).cache()
Pitfall 3: Memory pressure from over-caching
When you cache more data than fits in memory, Spark evicts older cached partitions using LRU. This can cause thrashing where partitions are repeatedly evicted and recomputed.
Monitor the Spark UI for “Fraction Cached” below 100%. If you see this, either reduce what you’re caching or switch to MEMORY_AND_DISK.
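The same fraction can be approximated from code, which is handy in notebooks or automated checks. A rough sketch, assuming the developer API `SparkContext.getRDDStorageInfo` is available in your Spark version; `CacheReport` and `fractionCached` are hypothetical names. The numbers come from the same asynchronous status tracking that feeds the UI, so they can lag slightly behind the action that filled the cache, and entries with no cached partitions yet are not listed.

```scala
import org.apache.spark.SparkContext

// Hypothetical helper built on the developer API getRDDStorageInfo.
object CacheReport {
  /** Maps each cached RDD (name and id) to cached partitions / total partitions. */
  def fractionCached(sc: SparkContext): Map[String, Double] =
    sc.getRDDStorageInfo.map { info =>
      val label = s"${info.name} (id=${info.id})"
      val fraction =
        if (info.numPartitions == 0) 0.0
        else info.numCachedPartitions.toDouble / info.numPartitions
      label -> fraction
    }.toMap

  /** Prints entries that are only partially cached. */
  def warnIfPartial(sc: SparkContext): Unit =
    fractionCached(sc)
      .filter { case (_, fraction) => fraction < 1.0 }
      .foreach { case (label, fraction) =>
        println(f"Partially cached: $label (${fraction * 100}%.1f%%)")
      }
}
```

Cached DataFrames should also appear here, since they are backed by a persisted RDD internally, though their names are less readable than an RDD you named yourself.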
Summary and Quick Reference
| Scenario | Recommended Storage Level |
|---|---|
| Small dataset, multiple actions | MEMORY_ONLY |
| Large dataset, multiple actions | MEMORY_AND_DISK |
| Very large dataset, memory constrained | MEMORY_AND_DISK_SER |
| Extremely expensive recomputation | MEMORY_AND_DISK_2 |
| Dataset used only once | Don’t cache |
Decision checklist:
- Is this dataset used by multiple actions? If no, don’t cache.
- Does it fit in memory? If yes, use MEMORY_ONLY. If no, use MEMORY_AND_DISK.
- Is memory still tight? Add the _SER suffix and enable Kryo.
- Did you trigger an action after caching? If not, your cache is empty.
- Did you unpersist when done? If not, you’re leaking memory.
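The first three checklist questions can be written down as a small function, which is sometimes useful as a shared convention in a team codebase. This is only a sketch of the checklist as stated above, with hypothetical names (`CachePolicy`, `chooseLevel`); the boolean inputs are judgments you still have to make yourself.

```scala
import org.apache.spark.storage.StorageLevel

// Hypothetical encoding of the checklist; not a Spark API.
object CachePolicy {
  /** Returns None when the dataset should not be cached at all. */
  def chooseLevel(
      usedByMultipleActions: Boolean,
      fitsInMemory: Boolean,
      memoryTight: Boolean
  ): Option[StorageLevel] =
    if (!usedByMultipleActions) None // used once: caching only adds overhead
    else (fitsInMemory, memoryTight) match {
      case (true, false)  => Some(StorageLevel.MEMORY_ONLY)
      case (true, true)   => Some(StorageLevel.MEMORY_ONLY_SER)
      case (false, false) => Some(StorageLevel.MEMORY_AND_DISK)
      case (false, true)  => Some(StorageLevel.MEMORY_AND_DISK_SER)
    }
}
```

A caller would then write something like `CachePolicy.chooseLevel(true, false, false).foreach(df.persist)`, leaving the dataset untouched when the answer is None.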
Caching is one of the highest-leverage optimizations in Spark, but it requires intentionality. Cache strategically, monitor aggressively, and clean up after yourself.