PySpark - Cache and Persist DataFrame

Key Insights

  • Caching eliminates redundant computation in PySpark’s lazy evaluation model by storing intermediate DataFrames in memory or on disk, often delivering 2-10x speedups for iterative workloads
  • Use cache() for the default storage level (MEMORY_AND_DISK for DataFrames), and reach for persist() with an explicit storage level (such as DISK_ONLY or MEMORY_AND_DISK_2) when memory is tight or fault tolerance matters
  • Always unpersist DataFrames when finished to prevent memory pressure; Spark’s LRU eviction helps, but explicit cleanup gives you control over resource utilization in production pipelines

Introduction to DataFrame Caching in PySpark

PySpark operates on lazy evaluation, meaning transformations like filter(), select(), and join() aren’t executed immediately. Instead, Spark builds a logical execution plan and only computes results when you call an action like count(), show(), or write(). This design enables query optimization, but it has a critical implication: every action triggers a complete recomputation from the original data source.

Consider a scenario where you load a large dataset, apply several transformations, and then run multiple analyses. Without caching, Spark recomputes the entire transformation chain for each action—reading from disk, applying filters, performing joins, and aggregations repeatedly. This redundant computation wastes CPU cycles, I/O bandwidth, and time.

Here’s a concrete example demonstrating recomputation:

from pyspark.sql import SparkSession
import time

spark = SparkSession.builder.appName("CachingDemo").getOrCreate()

# Load a DataFrame with transformations
df = spark.range(0, 10000000).selectExpr("id", "id * 2 as doubled", "id % 100 as category")
filtered_df = df.filter("category < 50")

# First action - computes from scratch
start = time.time()
count1 = filtered_df.count()
print(f"First count: {count1}, Time: {time.time() - start:.2f}s")

# Second action - recomputes everything again
start = time.time()
count2 = filtered_df.filter("doubled > 1000").count()
print(f"Second count: {count2}, Time: {time.time() - start:.2f}s")

Each action triggers a full scan and computation. Caching breaks this pattern by materializing the DataFrame after the first computation, storing it in memory or disk for subsequent reuse.

Understanding the cache() Method

The cache() method is the simplest way to persist a DataFrame. It is shorthand for persist() with the default storage level, which for DataFrames is MEMORY_AND_DISK (the often-quoted MEMORY_ONLY default applies to RDDs). Partitions that fit are kept in memory for fast access, and any that don’t fit spill to local disk instead of being recomputed from scratch.

Use cache() when your DataFrame fits comfortably in cluster memory and you’ll access it multiple times. It’s ideal for iterative algorithms, interactive analysis, and scenarios where you branch into multiple computation paths from a common DataFrame.

Here’s how to implement caching:

from pyspark.sql import SparkSession
import time

spark = SparkSession.builder.appName("CacheExample").getOrCreate()

# Create and transform DataFrame
df = spark.range(0, 10000000).selectExpr(
    "id", 
    "id * 2 as doubled", 
    "id % 100 as category",
    "rand() as random_value"
)
filtered_df = df.filter("category < 50")

# Without cache
start = time.time()
count1 = filtered_df.count()
time1 = time.time() - start

start = time.time()
count2 = filtered_df.filter("doubled > 1000").count()
time2 = time.time() - start

print(f"Without cache - First: {time1:.2f}s, Second: {time2:.2f}s")

# With cache
cached_df = filtered_df.cache()

start = time.time()
count1 = cached_df.count()  # Triggers caching
time1 = time.time() - start

start = time.time()
count2 = cached_df.filter("doubled > 1000").count()  # Uses cached data
time2 = time.time() - start

print(f"With cache - First: {time1:.2f}s, Second: {time2:.2f}s")

You can verify caching through the DataFrame’s storage level:

# Check if DataFrame is cached
print(f"Storage Level: {cached_df.storageLevel}")
print(f"Use Memory: {cached_df.storageLevel.useMemory}")
print(f"Use Disk: {cached_df.storageLevel.useDisk}")

The Spark UI (typically at http://localhost:4040) provides detailed caching metrics under the Storage tab, showing memory used, size on disk, and the fraction of partitions cached.

The persist() Method and Storage Levels

While cache() works well for DataFrames that fit in memory, persist() offers fine-grained control through storage levels. This flexibility is crucial for production workloads where memory constraints require trade-offs between speed, memory usage, and fault tolerance.

Key storage levels include:

  • MEMORY_ONLY: Stores data in memory only; partitions that don’t fit are recomputed on access (the default for RDD cache())
  • MEMORY_AND_DISK: Spills partitions to disk when memory is insufficient (the default for DataFrame cache())
  • MEMORY_AND_DISK_2: Like MEMORY_AND_DISK, but replicates each partition on two nodes for fault tolerance
  • DISK_ONLY: Stores only on disk, useful when memory is extremely limited
  • OFF_HEAP: Stores data in off-heap memory (requires off-heap memory to be enabled)

Note that the serialized variants (MEMORY_ONLY_SER, MEMORY_AND_DISK_SER) exist only in the Scala/Java API; PySpark’s StorageLevel does not expose them, since Python data is always serialized anyway.

Here’s how to use different storage levels:

from pyspark import StorageLevel
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("PersistLevels").getOrCreate()

df = spark.range(0, 50000000).selectExpr(
    "id",
    "id * 2 as value",
    "cast(rand() * 100 as int) as category"
)

# Memory with disk spillover (the DataFrame default, good for large DataFrames)
df.persist(StorageLevel.MEMORY_AND_DISK)
df.count()      # Materialize
df.unpersist()  # Release before switching levels; persist() on an already-cached DataFrame is a no-op

# Disk-only persistence (when memory is critical)
df.persist(StorageLevel.DISK_ONLY)
df.count()
df.unpersist()

# Replicated persistence (fault tolerance at the cost of extra storage)
df.persist(StorageLevel.MEMORY_AND_DISK_2)
df.count()

For large DataFrames that exceed cluster memory, MEMORY_AND_DISK offers an excellent balance: partitions that fit stay in memory for speed, and the rest spill to disk rather than triggering out-of-memory errors or repeated recomputation. (In the Scala/Java API, the serialized levels can shrink the memory footprint further, often by 2-5x, at the cost of CPU for deserialization.)

# Best practice for large DataFrames
large_df = spark.read.parquet("s3://bucket/large-dataset/")
processed_df = large_df.filter("status = 'active'") \
                       .groupBy("user_id") \
                       .agg({"transaction_amount": "sum"}) \
                       .withColumnRenamed("sum(transaction_amount)", "total_amount") \
                       .persist(StorageLevel.MEMORY_AND_DISK)

# Use the cached DataFrame multiple times
processed_df.count()
processed_df.filter("total_amount > 1000").show()
processed_df.write.parquet("s3://bucket/output/")

Best Practices and When to Cache

Caching isn’t always beneficial. The decision depends on DataFrame size, reuse frequency, and transformation complexity. Cache when:

  1. Iterative algorithms repeatedly access the same DataFrame (machine learning, graph processing)
  2. Multiple actions on the same DataFrame are required
  3. Branching transformations create multiple computation paths from a common DataFrame
  4. Expensive transformations (complex joins, aggregations) precede multiple downstream operations

Here’s a machine learning pipeline demonstrating effective caching:

from pyspark import StorageLevel
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression

# Load and prepare data
raw_df = spark.read.parquet("s3://bucket/training-data/")
prepared_df = raw_df.filter("label is not null") \
                    .dropDuplicates() \
                    .fillna(0) \
                    .persist(StorageLevel.MEMORY_AND_DISK)

# Multiple operations benefit from cache
train_df, test_df = prepared_df.randomSplit([0.8, 0.2], seed=42)
train_df.cache()
test_df.cache()

# Training and evaluation both use cached data
assembler = VectorAssembler(inputCols=["feature1", "feature2"], outputCol="features")
train_assembled = assembler.transform(train_df)
test_assembled = assembler.transform(test_df)

lr = LogisticRegression()
model = lr.fit(train_assembled)
predictions = model.transform(test_assembled)

Branching operations also benefit significantly:

from pyspark.sql.functions import countDistinct

# Common base DataFrame used for multiple analyses
base_df = spark.read.parquet("s3://bucket/events/") \
               .filter("timestamp > '2024-01-01'") \
               .cache()

# Branch 1: User analytics
user_stats = base_df.groupBy("user_id").agg({"event_id": "count"})

# Branch 2: Event type analysis
event_stats = base_df.groupBy("event_type").agg({"duration": "avg"})

# Branch 3: Hourly trends (dict-style agg has no "countDistinct", so use the function)
hourly_stats = base_df.groupBy("hour").agg(countDistinct("user_id"))

Avoid caching when:

  • DataFrame is used only once
  • Transformations are simple and fast
  • Data size is very small
  • Memory is severely constrained

# Anti-pattern: unnecessary caching
small_df = spark.range(100).cache()  # Overhead exceeds benefit
result = small_df.count()  # Used only once

Unpersisting and Memory Management

Cached DataFrames consume cluster memory until explicitly unpersisted or evicted by Spark’s LRU (Least Recently Used) algorithm. In long-running applications, failing to unpersist leads to memory pressure, reduced performance, and potential out-of-memory errors.

Always unpersist DataFrames when they’re no longer needed:

# Process data with caching
df = spark.read.parquet("s3://bucket/data/")
cached_df = df.filter("status = 'active'").cache()

# Perform operations
result1 = cached_df.count()
result2 = cached_df.groupBy("category").count().collect()

# Explicitly release memory
cached_df.unpersist()

You can monitor storage memory programmatically:

# Check cache status
print(f"Is cached: {cached_df.is_cached}")

# After operations complete
cached_df.unpersist()
print(f"Is cached after unpersist: {cached_df.is_cached}")

For long-running applications, implement a cache management strategy:

def process_daily_batch(date):
    df = spark.read.parquet(f"s3://bucket/data/date={date}")
    cached_df = df.filter("valid = true").persist(StorageLevel.MEMORY_AND_DISK)
    
    try:
        # Perform multiple operations
        cached_df.write.mode("overwrite").parquet(f"s3://bucket/output/date={date}")
        stats = cached_df.groupBy("category").count().collect()
        return stats
    finally:
        # Ensure cleanup even if errors occur
        cached_df.unpersist()

Performance Benchmarking

Let’s quantify caching impact with a realistic benchmark:

from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col, count, when
import time

spark = SparkSession.builder.appName("CacheBenchmark").getOrCreate()

# Create a complex DataFrame
df = spark.range(0, 20000000).selectExpr(
    "id",
    "cast(rand() * 1000 as int) as value",
    "cast(rand() * 100 as int) as category"
).withColumn(
    "label",
    when(col("value") > 500, 1).otherwise(0)
)

# Apply transformations (alias aggregate columns so they are easy to reference later)
transformed_df = df.filter("category < 50") \
                   .groupBy("category", "label") \
                   .agg(avg("value").alias("avg_value"), count("id").alias("id_count"))

# Benchmark without cache
start = time.time()
count1 = transformed_df.count()
sum1 = transformed_df.agg({"avg_value": "sum"}).collect()
filter1 = transformed_df.filter("id_count > 100000").count()
no_cache_time = time.time() - start

# Benchmark with cache
cached_df = transformed_df.cache()
start = time.time()
count2 = cached_df.count()  # Materializes cache
sum2 = cached_df.agg({"avg_value": "sum"}).collect()
filter2 = cached_df.filter("id_count > 100000").count()
cache_time = time.time() - start

print(f"Without cache: {no_cache_time:.2f}s")
print(f"With cache: {cache_time:.2f}s")
print(f"Speedup: {no_cache_time/cache_time:.2f}x")

cached_df.unpersist()

In typical scenarios, caching delivers 2-10x speedups depending on transformation complexity and reuse frequency. The first action with cache takes similar time (or slightly longer due to caching overhead), but subsequent actions are dramatically faster.

Caching is a powerful optimization technique in PySpark, but it requires thoughtful application. Use it for iterative workloads and frequently accessed DataFrames, choose appropriate storage levels based on memory availability, and always clean up to maintain cluster health. Master these patterns, and you’ll build efficient, scalable data pipelines that maximize cluster resources.
