How to Cache a DataFrame in PySpark
Key Insights
- Caching stores a DataFrame in memory (or disk) after its first computation, preventing Spark from recomputing the entire lineage every time you trigger an action on that DataFrame.
- Use `cache()` for the default storage level, but prefer `persist()` with an explicit storage level when you need control over memory pressure and fault tolerance.
- Always call `unpersist()` when you’re done with a cached DataFrame—Spark’s automatic eviction isn’t always timely, and holding onto cached data unnecessarily starves other operations of resources.
Introduction
If you’ve ever watched a Spark job run the same expensive transformation multiple times, you’ve experienced the cost of ignoring caching. Spark’s lazy evaluation model means it doesn’t store intermediate results by default. Every action—count(), show(), write()—triggers a full recomputation from the source data through every transformation in the lineage.
Caching fixes this. When you cache a DataFrame, Spark stores the computed result after the first action, and subsequent actions read from that stored copy instead of recomputing everything. For workflows that reuse DataFrames—think iterative machine learning algorithms, exploratory analysis, or multi-branch pipelines—caching can cut execution time dramatically.
But caching isn’t free. It consumes cluster memory, and misusing it can actually hurt performance. This article covers how caching works, when to use it, and how to do it correctly.
Understanding Spark’s Lazy Evaluation
Spark builds a directed acyclic graph (DAG) of transformations when you chain operations on a DataFrame. Nothing actually executes until you call an action. This lazy evaluation enables Spark’s query optimizer to reorder and combine operations for efficiency.
The downside: Spark doesn’t remember intermediate results. Consider this example:
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as spark_sum

spark = SparkSession.builder.appName("CachingDemo").getOrCreate()

# Load and transform data
raw_df = spark.read.parquet("s3://bucket/large-dataset/")
filtered_df = raw_df.filter(col("status") == "active")
aggregated_df = filtered_df.groupBy("region").agg(spark_sum("revenue").alias("total_revenue"))

# Multiple actions on the same DataFrame
print(aggregated_df.count())            # Full computation: read -> filter -> aggregate
aggregated_df.show()                    # Full computation again: read -> filter -> aggregate
aggregated_df.write.parquet("output/")  # And again: read -> filter -> aggregate
```
Each action triggers the entire pipeline from scratch. If your source data is large or your transformations are expensive, you’re wasting significant compute time. Check the Spark UI’s SQL tab—you’ll see three separate jobs, each reading from the source and applying all transformations.
cache() vs persist(): Core Methods
PySpark provides two methods for caching: cache() and persist(). They’re closely related, but persist() gives you more control.
Using cache()
The cache() method is shorthand for persist() with the default storage level (MEMORY_AND_DISK_DESER since Spark 3.0; MEMORY_AND_DISK in older releases). Either way, it stores the DataFrame in memory across the cluster, spilling partitions to disk if memory runs out.
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as spark_sum

spark = SparkSession.builder.appName("CacheExample").getOrCreate()

raw_df = spark.read.parquet("s3://bucket/large-dataset/")
filtered_df = raw_df.filter(col("status") == "active")
aggregated_df = filtered_df.groupBy("region").agg(spark_sum("revenue").alias("total_revenue"))

# Cache the DataFrame
aggregated_df.cache()

# First action triggers computation and caching
print(aggregated_df.count())  # Computes and stores in memory

# Subsequent actions read from cache
aggregated_df.show()                    # Reads from cache—fast
aggregated_df.write.parquet("output/")  # Reads from cache—fast
```
Important: cache() is lazy. Calling it doesn’t immediately compute and store the DataFrame. The caching happens when the first action executes. This catches people off guard—if you call cache() but never trigger an action, nothing gets cached.
Using persist() with Storage Levels
When you need finer control, use persist() with an explicit storage level:
```python
from pyspark import StorageLevel
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as spark_sum

spark = SparkSession.builder.appName("PersistExample").getOrCreate()

raw_df = spark.read.parquet("s3://bucket/large-dataset/")
filtered_df = raw_df.filter(col("status") == "active")
aggregated_df = filtered_df.groupBy("region").agg(spark_sum("revenue").alias("total_revenue"))

# Persist with explicit storage level
aggregated_df.persist(StorageLevel.MEMORY_AND_DISK)

# Trigger caching
aggregated_df.count()
```
Available storage levels in PySpark:

| Storage Level | Description |
|---|---|
| `MEMORY_ONLY` | Store in memory only. Partitions that don’t fit are not cached and are recomputed when needed. |
| `MEMORY_AND_DISK` | Store in memory (serialized on the JVM side); spill to disk if memory is insufficient. |
| `MEMORY_AND_DISK_DESER` | Deserialized in-memory storage with disk spillover—faster to read, larger in memory. Default for `cache()` since Spark 3.0. |
| `DISK_ONLY` | Store only on disk. Useful for very large DataFrames that don’t fit in memory. |

Note that the serialized variants you may see in Scala code (MEMORY_ONLY_SER, MEMORY_AND_DISK_SER) are not exposed in PySpark—the non-DESER levels above already store data serialized.
For most use cases, the default MEMORY_AND_DISK_DESER works well. When memory pressure is high, use MEMORY_AND_DISK instead—keeping the JVM-side data serialized reduces the memory footprint at the cost of CPU overhead during reads.

```python
# For memory-constrained clusters with large DataFrames
aggregated_df.persist(StorageLevel.MEMORY_AND_DISK)
```
When to Cache (and When Not To)
Caching isn’t always beneficial. Here’s when it makes sense:
Cache when:
- A DataFrame is used in multiple actions or branches of your pipeline
- The DataFrame is expensive to compute (complex joins, aggregations, UDFs)
- The DataFrame fits reasonably in cluster memory
- You’re running iterative algorithms (ML training, graph processing)
Don’t cache when:
- The DataFrame is used only once
- The DataFrame is trivially cheap to recompute (simple filters on already-cached data)
- Your cluster is memory-constrained and the DataFrame is large
- You’re reading from a fast source that’s cheaper to re-read than to cache
A common mistake is caching everything. This backfires when cached data evicts other cached data or reduces memory available for shuffles and joins. Be selective.
```python
from pyspark.sql.functions import col, sum as spark_sum

# raw_df and lookup_df are assumed to be loaded earlier in the session

# Good: Cache after expensive operations, before multiple uses
expensive_df = (
    raw_df
    .join(lookup_df, "key")
    .groupBy("category")
    .agg(spark_sum("value").alias("total"))
)
expensive_df.cache()

# Use multiple times
expensive_df.filter(col("total") > 1000).show()
expensive_df.filter(col("total") < 100).show()
expensive_df.write.parquet("output/")

# Bad: Caching a DataFrame used once
one_time_df = raw_df.filter(col("date") == "2024-01-01")
one_time_df.cache()  # Pointless—only used once
one_time_df.write.parquet("output/one_time/")
```
Verifying and Monitoring Cached Data
After calling cache() or persist(), you should verify that caching actually happened.
Programmatic Checks
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("VerifyCache").getOrCreate()

df = spark.range(1000000)
df.cache()

# Before any action—marked for caching, but nothing stored yet
print(f"Is cached: {df.is_cached}")        # True (indicates intent)
print(f"Storage level: {df.storageLevel}") # StorageLevel(True, True, False, True, 1)

# Trigger caching
df.count()

# Now actually cached
print(f"Is cached: {df.is_cached}")        # True
print(f"Storage level: {df.storageLevel}") # Same level, now actually materialized
```
Note that is_cached returns True as soon as you call cache(), even before the data is actually stored. It indicates intent, not completion.
Spark UI Storage Tab
The most reliable way to verify caching is the Spark UI. Navigate to the Storage tab to see:
- Which RDDs/DataFrames are cached
- How much memory and disk they consume
- What fraction is cached (some partitions might not fit)
- The storage level in use
If your DataFrame doesn’t appear in the Storage tab after triggering an action, caching failed—usually due to memory pressure.
Releasing Cached Data
Cached DataFrames consume cluster resources until explicitly released or evicted. Don’t rely on automatic cleanup.
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("UnpersistExample").getOrCreate()

df = spark.range(1000000)
df.cache()
df.count()  # Cached

# When you're done with the DataFrame
df.unpersist()

# Verify
print(f"Is cached: {df.is_cached}")  # False
```
By default, unpersist() is asynchronous—it returns immediately while Spark removes the cached data in the background. If you need synchronous behavior (waiting until the cache is fully cleared), pass blocking=True:
```python
df.unpersist(blocking=True)  # Waits until cache is cleared
```
Call unpersist() as soon as you’re done with a cached DataFrame. This is especially important in long-running applications or notebooks where cached data accumulates.
Spark does have automatic eviction using LRU (least recently used) policy when memory pressure is high, but relying on this leads to unpredictable performance. Explicit cleanup is better.
Best Practices Summary
Follow these guidelines for effective caching:
- **Cache after expensive operations.** Place `cache()` after joins, aggregations, and complex transformations—not on raw data reads.
- **Trigger caching explicitly.** Call an action like `count()` immediately after `cache()` to ensure the data is stored before subsequent operations.
- **Monitor memory usage.** Check the Spark UI Storage tab regularly. If cached data is being evicted, consider a serialized level like `MEMORY_AND_DISK` or reduce what you cache.
- **Unpersist when done.** Don’t let cached DataFrames linger. Call `unpersist()` as soon as you’re finished with them.
- **Consider serialized storage for large datasets.** Keeping JVM-side data serialized trades CPU time for memory efficiency—often a good tradeoff on memory-constrained clusters.
- **Don’t cache everything.** Be selective. Caching one-time-use DataFrames wastes memory and can hurt overall performance.
- **Cache at the right granularity.** Cache the DataFrame that’s actually reused, not its ancestors. Caching too early means storing more data than necessary.
Caching is one of the most effective performance optimizations in Spark, but only when used deliberately. Understand your data flow, identify reuse points, and cache strategically.