Apache Spark - Caching Strategies (MEMORY_ONLY, MEMORY_AND_DISK, etc.)
Key Insights
- `cache()` is just a shortcut for `persist(StorageLevel.MEMORY_AND_DISK)`—understanding the full range of storage levels lets you optimize for your specific memory, speed, and fault-tolerance requirements.
- Serialized storage levels (`_SER` variants) can reduce memory usage by 2-5x at the cost of CPU overhead, making them essential when caching large datasets on memory-constrained clusters.
- Blindly caching everything hurts performance—cache only DataFrames that are accessed multiple times, and always call `unpersist()` when you’re done to free resources for other operations.
Introduction to Spark Caching
Spark’s lazy evaluation model means transformations aren’t executed until an action triggers computation. Without caching, every action recomputes the entire lineage from scratch. For iterative algorithms or interactive analysis where you hit the same DataFrame multiple times, this recomputation becomes a massive bottleneck.
Caching stores the computed partitions in memory, on disk, or both, so subsequent actions skip the expensive recomputation. The performance difference can be dramatic—turning minute-long queries into sub-second responses.
The API offers two methods: cache() and persist(). They’re often confused, but the distinction is simple:
```python
from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

spark = SparkSession.builder.appName("CachingDemo").getOrCreate()

# These two are equivalent for DataFrames
df.cache()
df.persist(StorageLevel.MEMORY_AND_DISK)

# persist() lets you specify the storage level. Note that Spark refuses to
# change the level of an already-persisted DataFrame, so unpersist() first.
df.unpersist()
df.persist(StorageLevel.MEMORY_ONLY)   # or StorageLevel.DISK_ONLY, etc.
```
Here’s a concrete example showing the impact:
```python
import time

from pyspark.sql import functions as F

# Create a moderately expensive DataFrame
raw_df = spark.read.parquet("s3://bucket/large-dataset/")
expensive_df = (
    raw_df.groupBy("category")
    .agg(
        F.sum("amount").alias("total_amount"),
        F.avg("quantity").alias("avg_quantity"),
        F.max("price").alias("max_price"),
    )
    .filter(F.col("total_amount") > 1000)
)

# First run without caching
start = time.time()
expensive_df.count()
print(f"First count (no cache): {time.time() - start:.2f}s")

start = time.time()
expensive_df.show(10)
print(f"Show (no cache): {time.time() - start:.2f}s")

# Now with caching
expensive_df.cache()
expensive_df.count()  # Materializes the cache

start = time.time()
expensive_df.count()
print(f"Second count (cached): {time.time() - start:.2f}s")

start = time.time()
expensive_df.show(10)
print(f"Show (cached): {time.time() - start:.2f}s")
```
On a typical cluster, you’ll see the cached operations complete 10-100x faster, depending on the complexity of your transformations.
Understanding Storage Levels
The StorageLevel class controls exactly how Spark stores cached data. It’s defined by four boolean flags plus a replication count:

```python
from pyspark.storagelevel import StorageLevel

# Inspect what each level actually means; the repr fields are
# (useDisk, useMemory, useOffHeap, deserialized, replication)
print(StorageLevel.MEMORY_ONLY)
# StorageLevel(False, True, False, False, 1)

print(StorageLevel.MEMORY_AND_DISK_2)
# StorageLevel(True, True, False, False, 2)

# Note: PySpark's built-in levels all set deserialized=False, because
# Python data is always stored serialized; the JVM API's MEMORY_ONLY
# and MEMORY_AND_DISK are deserialized.

# You can even create custom levels (though rarely needed)
custom_level = StorageLevel(
    useDisk=True,
    useMemory=True,
    useOffHeap=False,
    deserialized=True,
    replication=2,
)
```
The flags break down as follows:
- useDisk: Store partitions on local disk when they don’t fit in memory
- useMemory: Store partitions in JVM heap memory
- useOffHeap: Store outside JVM heap using Tungsten’s memory management
- deserialized: Keep objects as Java/Python objects (True) or serialize them (False)
- replication: Number of cluster nodes to store each partition on
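To see how the flags compose into the named levels, here is a small pure-Python table mirroring the JVM `StorageLevel` definitions (flag values transcribed from the Scala API; a reference sketch, not executable Spark code):

```python
from collections import namedtuple

Level = namedtuple("Level", "use_disk use_memory use_off_heap deserialized replication")

# Flag combinations as defined in the Scala/Java StorageLevel object
JVM_LEVELS = {
    "MEMORY_ONLY":           Level(False, True,  False, True,  1),
    "MEMORY_ONLY_SER":       Level(False, True,  False, False, 1),
    "MEMORY_AND_DISK":       Level(True,  True,  False, True,  1),
    "MEMORY_AND_DISK_SER":   Level(True,  True,  False, False, 1),
    "MEMORY_AND_DISK_SER_2": Level(True,  True,  False, False, 2),
    "DISK_ONLY":             Level(True,  False, False, False, 1),
    "OFF_HEAP":              Level(True,  True,  True,  False, 1),
}

for name, level in JVM_LEVELS.items():
    print(f"{name:<22} {level}")
```

Reading across a row tells you everything about a level: the `_SER` variants flip `deserialized` off, the `_2` variants bump `replication` to 2, and OFF_HEAP is the only level that sets `use_off_heap`.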
MEMORY_ONLY and MEMORY_ONLY_SER
MEMORY_ONLY is the fastest option when it works. In the JVM API, partitions stay as deserialized Java objects in the heap, meaning zero serialization overhead on read. The catch: Java objects are memory hogs. A DataFrame that’s 1GB on disk might consume 4-5GB as deserialized objects.
```python
from pyspark.storagelevel import StorageLevel

# Fast access, high memory usage
df.persist(StorageLevel.MEMORY_ONLY)

# Trigger caching and check memory consumption in the Spark UI
df.count()

# Check the storage level programmatically
print(f"Storage level: {df.storageLevel}")
print(f"Is cached: {df.is_cached}")
```
When memory is insufficient, Spark doesn’t spill to disk—it simply doesn’t cache the partitions that don’t fit. Those partitions get recomputed on every action. This “silent partial caching” catches many developers off guard.
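One way to catch silent partial caching outside the UI is Spark's monitoring REST API, whose `/applications/<app-id>/storage/rdd` endpoint reports `numCachedPartitions` alongside `numPartitions` for each persisted dataset. A sketch of the arithmetic on that payload (the example dicts mimic the endpoint's shape; the HTTP fetch itself is omitted):

```python
def cached_fraction(rdd_info: dict) -> float:
    """Fraction of a persisted dataset's partitions that made it into the cache."""
    total = rdd_info["numPartitions"]
    return rdd_info["numCachedPartitions"] / total if total else 0.0

def find_partially_cached(storage_payload: list) -> list:
    """Names of persisted datasets with at least one uncached partition."""
    return [r["name"] for r in storage_payload if cached_fraction(r) < 1.0]

# Example payload shape (trimmed to the relevant fields)
payload = [
    {"name": "rdd_42", "numPartitions": 200, "numCachedPartitions": 200},
    {"name": "rdd_57", "numPartitions": 200, "numCachedPartitions": 163},
]
print(find_partially_cached(payload))  # ['rdd_57']
```

Anything this flags is being partially recomputed on every action, which is exactly the situation MEMORY_AND_DISK avoids.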
MEMORY_ONLY_SER trades CPU for memory. Spark serializes partitions into compact byte arrays, typically achieving 2-5x compression. (Note that the `_SER` constants belong to the Scala/Java API; PySpark 2.0+ no longer defines them, because Python data is always stored serialized.)

```python
from pyspark.storagelevel import StorageLevel

# More compact, requires deserialization on read. In Scala/Java:
#   df.persist(StorageLevel.MEMORY_ONLY_SER)
# PySpark 2.0+ removed the _SER constants, so build the level by hand:
MEMORY_ONLY_SER = StorageLevel(False, True, False, False, 1)
df.persist(MEMORY_ONLY_SER)
df.count()

# Compare memory usage in the Spark UI Storage tab
# MEMORY_ONLY (deserialized): ~4.2 GB
# MEMORY_ONLY_SER:            ~1.1 GB (typical for columnar data)
```
Use MEMORY_ONLY_SER when:
- Your cluster is memory-constrained
- You’re caching large datasets
- Read latency isn’t critical (batch processing vs. interactive queries)
For DataFrames specifically, Spark’s Tungsten engine already uses an efficient columnar format, so the serialization overhead is lower than with RDDs.
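The savings from serialized storage are easy to demonstrate even outside Spark. A rough plain-Python analogy, with pickle standing in for Spark's serializer (the numbers are illustrative, not Spark's actual memory layout):

```python
import pickle
import sys

# 10,000 small records, like rows of a narrow DataFrame
rows = [{"category": f"cat{i % 10}", "amount": float(i)} for i in range(10_000)]

# Rough "deserialized" footprint: per-object overhead of the live dicts
# (ignores the strings/floats inside, so it understates the real cost)
object_bytes = sum(sys.getsizeof(r) for r in rows)

# "Serialized" footprint: one compact byte blob
blob_bytes = len(pickle.dumps(rows, protocol=pickle.HIGHEST_PROTOCOL))

print(f"live objects: ~{object_bytes:,} B vs serialized: {blob_bytes:,} B")
```

The same effect, magnified by object headers and pointer indirection on the JVM, is why `_SER` levels routinely shrink a cache by several times.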
MEMORY_AND_DISK Variants
MEMORY_AND_DISK is the default for cache() and the most practical choice for production workloads. Partitions that fit in memory stay there; overflow spills to local disk rather than being dropped.
```python
from pyspark.sql import functions as F
from pyspark.storagelevel import StorageLevel

# Spill to disk when memory is full
large_df.persist(StorageLevel.MEMORY_AND_DISK)

# To force a scenario where spilling occurs, launch the application with
# limited executor memory (this can't be changed at runtime):
#   spark-submit --executor-memory 2g ...

# Process a dataset larger than available memory
huge_df = spark.range(0, 500000000).withColumn(
    "data", F.expr("uuid()")
)
huge_df.persist(StorageLevel.MEMORY_AND_DISK)
huge_df.count()

# Check the Spark UI Storage tab:
# - Memory Size: 1.8 GB
# - Disk Size: 3.2 GB (spilled portion)
```
The serialized variant MEMORY_AND_DISK_SER combines the memory efficiency of serialization with disk spillover:

```python
# Best of both worlds for large datasets. In Scala/Java:
#   huge_df.persist(StorageLevel.MEMORY_AND_DISK_SER)
# In PySpark, MEMORY_AND_DISK already stores data serialized:
huge_df.persist(StorageLevel.MEMORY_AND_DISK)
```
Disk I/O is acceptable when:
- SSDs are available (NVMe preferred)
- Recomputation cost exceeds disk read time
- You’re doing batch processing, not interactive queries
Avoid disk spillover when:
- Running on HDDs (seek time kills performance)
- Network-attached storage is your only option
- Sub-second response times are required
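The "recomputation cost exceeds disk read time" test above is simple arithmetic. A hypothetical back-of-envelope helper (the throughput figures are rough ballpark numbers, not measurements):

```python
def spill_pays_off(partition_bytes: int, recompute_seconds: float,
                   disk_mb_per_s: float) -> bool:
    """True if re-reading a spilled partition beats recomputing it."""
    read_seconds = partition_bytes / (disk_mb_per_s * 1024 * 1024)
    return read_seconds < recompute_seconds

partition = 128 * 1024 * 1024   # a typical 128 MB partition
recompute = 0.5                 # seconds of upstream work to rebuild it

# Illustrative sequential-read throughputs
print(spill_pays_off(partition, recompute, disk_mb_per_s=2000))  # NVMe SSD: True
print(spill_pays_off(partition, recompute, disk_mb_per_s=120))   # HDD: False
```

With a cheap transformation, the HDD loses to recomputation while the NVMe drive wins comfortably, which matches the guidance above.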
DISK_ONLY and OFF_HEAP Options
DISK_ONLY is rarely the right choice, but it has legitimate uses:
```python
# Cache to disk only—useful for checkpointing expensive computations
# when memory is needed for other operations
checkpoint_df.persist(StorageLevel.DISK_ONLY)
```
Use DISK_ONLY when:
- You need to cache many DataFrames simultaneously
- Memory must be reserved for shuffle operations
- The cached data is accessed infrequently but recomputation is expensive
Off-heap storage moves data outside the JVM heap, avoiding garbage collection overhead:
```python
# Configure off-heap memory first
spark = SparkSession.builder \
    .appName("OffHeapDemo") \
    .config("spark.memory.offHeap.enabled", "true") \
    .config("spark.memory.offHeap.size", "4g") \
    .getOrCreate()

# Now you can use off-heap storage
df.persist(StorageLevel.OFF_HEAP)
```
Off-heap benefits include:
- No GC pauses on cached data
- More predictable latency
- Better memory utilization for large heaps
The downside: more complex configuration and debugging. Stick with on-heap storage unless GC is a proven bottleneck.
Replication Strategies (_2 Variants)
Most storage levels have a replicated variant (suffix _2) that stores each partition on two nodes:

```python
# Replicate cached data for fault tolerance
critical_df.persist(StorageLevel.MEMORY_AND_DISK_2)

# If an executor dies, the partition is available on another node
# without recomputation
```
Replication doubles your storage requirements but provides:
- Faster recovery from executor failures
- No recomputation when nodes die
- Consistent performance during cluster instability
```python
from pyspark.storagelevel import StorageLevel

# For mission-critical pipelines on spot instances.
# (PySpark's levels are already serialized; the Scala/Java API would use
# MEMORY_AND_DISK_SER_2 / MEMORY_AND_DISK_SER here.)
def cache_with_fault_tolerance(df, critical=False):
    if critical:
        return df.persist(StorageLevel.MEMORY_AND_DISK_2)
    return df.persist(StorageLevel.MEMORY_AND_DISK)

# Use replication for data that's expensive to recompute
# and runs on preemptible/spot instances
model_features = expensive_feature_engineering(raw_data)
cache_with_fault_tolerance(model_features, critical=True)
```
Best Practices and Monitoring
Choose your storage level based on these factors:
- Data size vs. available memory: If cached data fits in 60% of executor memory, use `MEMORY_ONLY`. Otherwise, use `MEMORY_AND_DISK_SER`.
- Access frequency: Interactive queries benefit from deserialized storage. Batch jobs can tolerate serialization overhead.
- Cluster stability: Use `_2` variants on spot instances or unreliable hardware.
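The first and third rules of thumb condense into a tiny helper (thresholds and level names taken from the list above; a sketch, not an official API):

```python
def pick_storage_level(data_gb: float, executor_memory_gb: float,
                       spot_instances: bool = False) -> str:
    """Suggest a storage-level name using the rules of thumb above."""
    fits = data_gb <= 0.6 * executor_memory_gb
    level = "MEMORY_ONLY" if fits else "MEMORY_AND_DISK_SER"
    return level + "_2" if spot_instances else level

print(pick_storage_level(10, 32))                        # MEMORY_ONLY
print(pick_storage_level(100, 32, spot_instances=True))  # MEMORY_AND_DISK_SER_2
```

The access-frequency rule stays a human judgment call: interactive workloads may justify paying the memory cost of a deserialized level even when the helper suggests otherwise.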
Always unpersist when done:
```python
# Process data in stages, releasing cache as you go
stage1_df = raw_df.transform(stage1_processing)
stage1_df.cache()
stage1_df.count()  # Materialize

stage2_df = stage1_df.transform(stage2_processing)
stage2_df.cache()
stage2_df.count()  # Materialize

# Release stage1 cache—we don't need it anymore
stage1_df.unpersist()

# Continue processing with stage2_df
final_result = stage2_df.transform(final_processing)
final_result.write.parquet("output/")
stage2_df.unpersist()
```
Monitor cache effectiveness:
```python
# Programmatic cache inspection
def inspect_cache(df, name="DataFrame"):
    print(f"{name}:")
    print(f"  Cached: {df.is_cached}")
    print(f"  Storage Level: {df.storageLevel}")

# For detailed metrics, use the Spark UI or, for registered tables:
spark.catalog.isCached(table_name)

# Clear all cached data (useful in notebooks)
spark.catalog.clearCache()
```
The Spark UI Storage tab shows exactly what’s cached, how much memory and disk each DataFrame consumes, and which partitions are cached on which executors. Check it regularly—you’ll often find forgotten caches consuming resources.
The golden rule: cache DataFrames that are accessed multiple times and expensive to compute. Everything else should flow through without caching. When in doubt, benchmark both approaches on representative data.
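For that final benchmarking step, a small timing helper is often all you need (plain Python; pass any Spark action as a zero-argument callable, e.g. `lambda: df.count()`):

```python
import time

def benchmark(action, runs=3):
    """Run a zero-argument action several times; return the best wall-clock time."""
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        action()
        timings.append(time.perf_counter() - start)
    return min(timings)

# Usage sketch against a cached vs. uncached DataFrame:
#   uncached = benchmark(lambda: expensive_df.count())
#   expensive_df.cache(); expensive_df.count()   # materialize
#   cached = benchmark(lambda: expensive_df.count())
#   print(f"speedup: {uncached / cached:.1f}x")
print(f"timer overhead: {benchmark(lambda: None):.6f}s")
```

Taking the minimum of several runs filters out one-off noise such as JIT warm-up or a cold connection, giving a fairer comparison between the cached and uncached paths.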