PySpark - RDD Persistence (cache, persist)
Key Insights
• RDD persistence stores intermediate results in memory or disk to avoid recomputation, critical for iterative algorithms and interactive analysis where the same dataset is accessed multiple times
• cache() is a shorthand for persist(StorageLevel.MEMORY_ONLY), while persist() offers granular control over storage locations including memory, disk, off-heap, and serialization options
• Unpersisting RDDs explicitly with unpersist() prevents memory bloat in long-running applications, as Spark’s LRU eviction may not always align with your application’s memory requirements
Understanding RDD Persistence Fundamentals
PySpark’s lazy evaluation model recomputes RDDs from their lineage graph each time an action is called. For transformations applied repeatedly on the same dataset, this recomputation becomes a performance bottleneck. Persistence breaks this pattern by materializing RDD data in memory, disk, or both.
Consider a typical machine learning scenario where you preprocess data once but use it in multiple training iterations:
from pyspark import SparkContext, StorageLevel
sc = SparkContext("local[*]", "PersistenceDemo")
# Load and transform data
raw_data = sc.textFile("logs.txt")
cleaned_data = raw_data.filter(lambda x: len(x) > 0) \
    .map(lambda x: x.strip().lower()) \
    .filter(lambda x: not x.startswith("#"))
# Without persistence, this lineage executes twice
count1 = cleaned_data.count()
count2 = cleaned_data.count()
Without persistence, Spark reads logs.txt and applies every transformation twice, once per action. With persistence:
# Cache the cleaned data
cleaned_data.cache()
# First action materializes and stores the RDD
count1 = cleaned_data.count()
# Second action reads from cache
count2 = cleaned_data.count()
Storage Levels and Memory Trade-offs
The persist() method accepts a StorageLevel parameter that controls where and how data is stored. Each level represents different trade-offs between CPU, memory, and I/O:
from pyspark import StorageLevel
# Memory only - fastest but risky for large datasets
rdd.persist(StorageLevel.MEMORY_ONLY)
# Memory and disk - spills to disk when memory is full
rdd.persist(StorageLevel.MEMORY_AND_DISK)
# Serialized in memory - saves space, costs CPU (note: PySpark always pickles
# data, so in Python this behaves like MEMORY_ONLY; some PySpark versions
# omit the _SER levels entirely)
rdd.persist(StorageLevel.MEMORY_ONLY_SER)
# Disk only - slowest but handles any size
rdd.persist(StorageLevel.DISK_ONLY)
# Replicated variants for fault tolerance
rdd.persist(StorageLevel.MEMORY_AND_DISK_2)
Here’s a practical comparison with real data:
import time
# Generate test dataset
large_rdd = sc.parallelize(range(10000000), 100) \
    .map(lambda x: (x, x * x, x * x * x))
# Test MEMORY_ONLY
large_rdd.persist(StorageLevel.MEMORY_ONLY)
start = time.time()
large_rdd.count()
first_run = time.time() - start
start = time.time()
large_rdd.count()
cached_run = time.time() - start
print(f"First run: {first_run:.2f}s, Cached run: {cached_run:.2f}s")
# Clean up and test MEMORY_ONLY_SER
large_rdd.unpersist()
large_rdd.persist(StorageLevel.MEMORY_ONLY_SER)
large_rdd.count() # Materialize
# Check storage info
print(f"Storage Level: {large_rdd.getStorageLevel()}")
Practical Patterns for Caching Decisions
Not every RDD benefits from persistence. Cache when:
- Iterative algorithms access the same dataset multiple times
- Branching workflows derive multiple RDDs from a common ancestor
- Interactive analysis requires repeated queries on the same data
# Pattern 1: Iterative algorithm (PageRank-style)
links = sc.parallelize([
    ("page1", ["page2", "page3"]),
    ("page2", ["page1"]),
    ("page3", ["page1", "page2"])
])
ranks = sc.parallelize([
    ("page1", 1.0),
    ("page2", 1.0),
    ("page3", 1.0)
])
# Cache links since it's used in every iteration
links.cache()
for iteration in range(10):
    contribs = links.join(ranks).flatMap(
        lambda x: [(dest, x[1][1] / len(x[1][0])) for dest in x[1][0]]
    )
    ranks = contribs.reduceByKey(lambda x, y: x + y) \
        .mapValues(lambda v: 0.15 + 0.85 * v)
# Pattern 2: Branching workflow
base_data = sc.textFile("transactions.csv").map(parse_transaction)  # parse_transaction: user-defined parser returning a dict
base_data.cache()
# Multiple analyses from same base
revenue_by_region = base_data.map(lambda x: (x['region'], x['amount'])) \
    .reduceByKey(lambda a, b: a + b)
revenue_by_product = base_data.map(lambda x: (x['product'], x['amount'])) \
    .reduceByKey(lambda a, b: a + b)
top_customers = base_data.map(lambda x: (x['customer'], x['amount'])) \
    .reduceByKey(lambda a, b: a + b) \
    .takeOrdered(10, key=lambda x: -x[1])
Avoid caching when:
# Anti-pattern: Caching RDDs used only once
data = sc.textFile("data.txt")
data.cache() # Wasteful
result = data.filter(lambda x: "error" in x).count()
# Anti-pattern: Caching very small datasets
tiny_rdd = sc.parallelize([1, 2, 3, 4, 5])
tiny_rdd.cache() # Overhead exceeds benefit
# Anti-pattern: Caching before expensive shuffles
large_dataset.cache()
result = large_dataset.groupByKey().mapValues(list) # Cache after groupByKey instead
Managing Persistence Lifecycle
Explicit memory management prevents resource leaks in long-running applications:
def process_batch(batch_id, data_path):
    # Load and cache batch-specific data
    batch_data = sc.textFile(f"{data_path}/batch_{batch_id}.txt")
    batch_data.cache()
    # Perform analysis
    results = batch_data.map(analyze).collect()
    # Critical: unpersist when done
    batch_data.unpersist()
    return results
# Process multiple batches without memory accumulation
for batch_id in range(100):
    results = process_batch(batch_id, "/data/input")
    save_results(results)
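The cache/unpersist pairing is easy to get wrong when the analysis step raises an exception. A context manager guarantees cleanup; this is a sketch in plain Python, assuming only that the object exposes cache() and unpersist() methods, as PySpark RDDs do:

```python
from contextlib import contextmanager

@contextmanager
def cached(rdd):
    """Cache an RDD for the duration of a with-block, unpersisting on exit."""
    rdd.cache()
    try:
        yield rdd
    finally:
        # Runs even if the body raises, preventing cache leaks
        rdd.unpersist()

# Usage with a real RDD would look like:
# with cached(sc.textFile(path)) as batch_data:
#     results = batch_data.map(analyze).collect()
```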
Monitor cache usage programmatically:
# Check if RDD is cached
print(f"Is cached: {rdd.is_cached}")
# Get storage level
print(f"Storage level: {rdd.getStorageLevel()}")
# Inspect cache statistics via the JVM gateway (private API; the Spark UI's
# Storage tab exposes the same information)
print(sc._jsc.sc().getRDDStorageInfo())
Checkpoint vs Persistence
Checkpointing differs from persistence by truncating the lineage graph and writing to reliable storage:
# Set checkpoint directory (HDFS in production)
sc.setCheckpointDir("/tmp/checkpoints")
# Long lineage chain
data = sc.textFile("input.txt")
for i in range(100):
    # Bind i as a default argument: Python lambdas capture loop variables
    # late, so without i=i every map would see the final value of i
    data = data.map(lambda x, i=i: complex_transformation(x, i))
# Checkpoint truncates lineage, persists to HDFS
data.checkpoint()
data.count() # Materializes the checkpoint
# vs persistence which keeps lineage
data.cache()
data.count() # Keeps full lineage graph
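A Python-specific caveat in loops that build lambdas, like the one above: closures capture loop variables by reference, so every lambda sees the final value of i unless it is bound explicitly. A minimal pure-Python sketch:

```python
# Late binding: each lambda looks up i when called, after the loop has ended
fns = [lambda x: x + i for i in range(3)]
print([f(10) for f in fns])  # [12, 12, 12]

# Binding i as a default argument freezes its value per iteration
fns = [lambda x, i=i: x + i for i in range(3)]
print([f(10) for f in fns])  # [10, 11, 12]
```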
Use checkpointing for:
- Very long lineage chains that risk stack overflow
- Fault tolerance in unreliable clusters
- Breaking circular dependencies in iterative algorithms
Use persistence for:
- Performance optimization within a single job
- Interactive data exploration
- Reusing intermediate results
Performance Optimization Strategies
Combine persistence with partitioning for optimal performance:
# Repartition before caching for better parallelism
unbalanced_rdd = sc.textFile("skewed_data.txt")
# Repartition to balance load, then cache
balanced_rdd = unbalanced_rdd.repartition(200)
balanced_rdd.cache()
# Coalesce before caching to avoid storing many near-empty partitions
large_filtered = sc.textFile("huge_file.txt") \
    .filter(lambda x: "critical" in x)
# Reduce partition count since filtering shrank the data
optimized = large_filtered.coalesce(50)
optimized.cache()
Serialize custom objects efficiently:
from pyspark import StorageLevel
class DataRecord:
    def __init__(self, id, value, metadata):
        self.id = id
        self.value = value
        self.metadata = metadata
# Use serialized storage for custom objects
records = sc.parallelize([DataRecord(i, i*2, f"meta_{i}") for i in range(1000000)])
# In Scala/Java, MEMORY_ONLY_SER shrinks object storage; in PySpark the data
# is already pickled, so this level is equivalent to MEMORY_ONLY
records.persist(StorageLevel.MEMORY_ONLY_SER)
The choice between cache() and persist() with specific storage levels directly impacts application performance and resource utilization. Profile your workload, monitor memory usage, and adjust persistence strategies accordingly.