PySpark: Optimization Techniques
Key Insights
- Understanding Spark’s Catalyst optimizer and using .explain() to analyze execution plans is the foundation of all PySpark optimization work
- Shuffle operations are the primary performance killer in distributed computing; broadcast joins and salting techniques can eliminate most shuffle-related bottlenecks
- Proper partitioning strategy combined with columnar file formats like Parquet can improve job performance by 10-100x without touching your transformation logic
Introduction to PySpark Performance
Distributed computing promises horizontal scalability, but that promise comes with a catch: poor code that runs slowly on a single machine runs catastrophically slowly across a cluster. I’ve seen PySpark jobs that should complete in minutes drag on for hours because developers treated Spark like pandas with more RAM.
The fundamental challenge is that Spark distributes data across nodes, and moving that data between nodes—shuffling—is expensive. Network I/O, serialization overhead, and disk spills all compound. Add data skew to the mix, and you’ll watch one executor struggle while the rest sit idle.
The good news: most PySpark performance problems fall into predictable categories, and the solutions are well-established. Let’s work through them systematically.
Understanding the Catalyst Optimizer & Tungsten Engine
Before optimizing anything, you need to understand what Spark actually does with your code. The Catalyst optimizer transforms your DataFrame operations through four phases: analysis, logical optimization, physical planning, and code generation.
When you write df.filter(col("status") == "active").select("id", "name"), Catalyst doesn’t execute operations in that order. It builds a logical plan, applies optimization rules (like pushing filters before projections), generates multiple physical plans, selects the cheapest one based on cost estimation, and then Tungsten generates optimized bytecode.
The .explain() method reveals this process:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
spark = SparkSession.builder.appName("OptimizationDemo").getOrCreate()
# Create sample DataFrames
orders = spark.read.parquet("s3://data/orders/")
customers = spark.read.parquet("s3://data/customers/")
# Unoptimized query pattern
result = orders.join(customers, "customer_id") \
.filter(col("order_date") > "2024-01-01") \
.select("order_id", "customer_name", "total")
# Analyze the execution plan
result.explain(mode="extended")
The extended explain output shows four plans: parsed logical, analyzed logical, optimized logical, and physical. Pay attention to the physical plan—look for Exchange nodes (shuffles), BroadcastHashJoin vs SortMergeJoin, and filter placement.
Compare this with an optimized version where we filter before joining:
# Optimized: filter early to reduce data volume before join
filtered_orders = orders.filter(col("order_date") > "2024-01-01")
result = filtered_orders.join(customers, "customer_id") \
.select("order_id", "customer_name", "total")
result.explain(mode="extended")
Catalyst should push the filter down automatically, but complex transformations can prevent this. Always verify with .explain().
Partitioning Strategies
Partitioning determines how data is distributed across executors. Too few partitions underutilize your cluster; too many create scheduling overhead and small file problems.
The default shuffle parallelism (spark.sql.shuffle.partitions) is 200 partitions, which is rarely optimal. A good starting point: aim for partitions between 128MB and 256MB each. For a 100GB dataset, that’s roughly 400-800 partitions.
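The sizing arithmetic above can be sketched as a quick helper. The function name and the 192MB default target are illustrative conveniences, not Spark APIs:

```python
# Rough partition-count estimate from total data size.
# target_mb is illustrative guidance (middle of the 128-256MB range),
# not a Spark configuration setting.
def estimate_partitions(total_gb: float, target_mb: int = 192) -> int:
    total_mb = total_gb * 1024
    return max(1, round(total_mb / target_mb))

# 100 GB at ~192MB per partition -> roughly 533 partitions,
# inside the 400-800 range quoted above
print(estimate_partitions(100))
```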
# Check current partition count
print(f"Current partitions: {df.rdd.getNumPartitions()}")
# Repartition for parallelism (causes full shuffle)
df_repartitioned = df.repartition(500)
# Coalesce to reduce partitions (no shuffle, but can cause skew)
df_coalesced = df.coalesce(100)
# Repartition by key for join optimization
orders_by_customer = orders.repartition(200, "customer_id")
customers_by_id = customers.repartition(200, "customer_id")
# The two repartitions above each shuffle once, but the join itself
# needs no additional shuffle: both sides share the same partitioning
result = orders_by_customer.join(customers_by_id, "customer_id")
The critical distinction: repartition() performs a full shuffle to create exactly N evenly-distributed partitions. coalesce() combines existing partitions without shuffling, which is faster but can create uneven partition sizes.
For reads, partition pruning eliminates entire partitions from scans:
# Assuming data is partitioned by date on disk
# Spark only reads relevant partitions
df = spark.read.parquet("s3://data/events/") \
.filter(col("event_date") == "2024-06-15")
This only works if your storage is physically partitioned by the filter column.
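For pruning to kick in, the data has to be written with matching directory partitioning in the first place. A minimal write-side sketch (the path and column name mirror the read example above and are assumptions):

```python
# Write events partitioned by event_date so each date lands in its own
# directory (e.g. .../event_date=2024-06-15/). Later filters on
# event_date can then skip entire directories at scan time.
df.write \
    .partitionBy("event_date") \
    .mode("overwrite") \
    .parquet("s3://data/events/")
```

Avoid partitioning by high-cardinality columns (user IDs, timestamps at second granularity): each distinct value becomes a directory, and you trade pruning gains for a small-files problem.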
Avoiding Shuffles and Managing Data Skew
Shuffles move data across the network. They’re necessary for operations like groupBy, join, and distinct, but they’re expensive. Your goal is to minimize shuffle data volume and avoid shuffles entirely when possible.
Broadcast joins eliminate shuffles for small tables by sending the entire small table to every executor:
from pyspark.sql.functions import broadcast
# Small dimension table (< 10MB is a safe threshold)
countries = spark.read.parquet("s3://data/countries/") # ~50KB
# Large fact table
transactions = spark.read.parquet("s3://data/transactions/") # 500GB
# Broadcast the small table—no shuffle required
result = transactions.join(broadcast(countries), "country_code")
Spark auto-broadcasts tables under spark.sql.autoBroadcastJoinThreshold (default 10MB), but explicit broadcasting makes intent clear and works even when statistics are unavailable.
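The threshold itself is tunable at runtime. A config sketch (the 50MB value is illustrative):

```python
# Raise the auto-broadcast cutoff to 50MB for clusters with roomy executors
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024)

# Or set -1 to disable auto-broadcasting entirely and rely only on
# explicit broadcast() hints
# spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
```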
Data skew is trickier. When one key has disproportionately more records than others, one executor handles most of the work. The salting technique distributes skewed keys across multiple partitions:
from pyspark.sql.functions import col, lit, concat, floor, rand
# Assume "company_id" is skewed—one company has 80% of orders
SALT_BUCKETS = 10
# Salt the large table
orders_salted = orders.withColumn(
"salt",
floor(rand() * SALT_BUCKETS).cast("int")
).withColumn(
"salted_key",
concat(col("company_id"), lit("_"), col("salt"))
)
# Explode the small table to match all salt values
from pyspark.sql.functions import explode, array
companies_exploded = companies.withColumn(
    "salt",
    explode(array([lit(i) for i in range(SALT_BUCKETS)]))
).withColumn(
    "salted_key",
    concat(col("company_id"), lit("_"), col("salt"))
).drop("company_id", "salt")  # avoid duplicate columns after the join
# Join on the salted key: work is distributed evenly
result = orders_salted.join(companies_exploded, "salted_key") \
    .drop("salt", "salted_key")
This trades increased data volume on the small table for even distribution of the large table.
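On Spark 3.x, Adaptive Query Execution offers a lower-effort alternative to manual salting for join skew. A config sketch; the factor and threshold shown are the stock defaults, spelled out for clarity:

```python
# Let AQE detect and split skewed partitions at runtime instead of
# salting by hand
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

# A partition counts as skewed if it is both 5x the median partition
# size AND larger than 256MB (these are the stock defaults)
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set(
    "spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
```

Manual salting still earns its keep when AQE is unavailable or when skew appears in aggregations rather than joins.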
Caching and Persistence
Caching stores intermediate results in memory, avoiding recomputation. It’s valuable when you reuse a DataFrame multiple times, but it consumes cluster memory.
from pyspark import StorageLevel
from pyspark.sql.functions import col, hour, sum as spark_sum
# Multi-step pipeline with reuse
base_df = spark.read.parquet("s3://data/events/") \
    .filter(col("event_type").isin(["click", "purchase"])) \
    .withColumn("event_hour", hour(col("event_time")))
# Cache because we'll use this twice
base_df.cache()
# First use
hourly_stats = base_df.groupBy("event_hour").count()
hourly_stats.write.parquet("s3://output/hourly_stats/")
# Second use
daily_summary = base_df.groupBy("event_type").agg(spark_sum("value"))
daily_summary.write.parquet("s3://output/daily_summary/")
# Release memory when done
base_df.unpersist()
Storage levels control the trade-off between memory usage and recomputation cost:
# Memory only (note: DataFrame.cache() actually defaults to
# MEMORY_AND_DISK; MEMORY_ONLY is the default for RDD.cache())
df.persist(StorageLevel.MEMORY_ONLY)
# Spill to disk if memory is insufficient
df.persist(StorageLevel.MEMORY_AND_DISK)
# Disk only, for data too large to keep in memory at all
# (PySpark already stores data in serialized form, so the Scala-side
# MEMORY_ONLY_SER level is not exposed in the Python API)
df.persist(StorageLevel.DISK_ONLY)
Don’t cache everything. Cache only DataFrames that are expensive to compute and used multiple times. Unnecessary caching evicts useful data from memory.
Efficient Data Serialization & File Formats
Parquet is the standard for analytical workloads. It’s columnar, compressed, and supports predicate pushdown and column pruning natively.
# Column pruning: only read columns you need
df = spark.read.parquet("s3://data/wide_table/") \
.select("id", "timestamp", "value") # Reads only 3 columns from disk
# Predicate pushdown: filter at storage layer
df = spark.read.parquet("s3://data/events/") \
.filter(col("country") == "US") # Pushed to Parquet reader
For serialization of RDD-based workloads, Kryo is faster than Java serialization (DataFrame operations use Tungsten’s own binary format, so Kryo matters most when you drop down to the RDD API):
spark = SparkSession.builder \
.appName("ProductionJob") \
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
.config("spark.kryo.registrationRequired", "false") \
.getOrCreate()
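If you instead set registrationRequired to true, every class that crosses the wire must be registered up front, which turns silently-slow fallback serialization into a fast failure. A sketch; the class names in the list are hypothetical placeholders:

```python
from pyspark.sql import SparkSession

# Strict Kryo setup: unregistered classes raise an error instead of
# falling back to slower generic serialization
spark = SparkSession.builder \
    .appName("ProductionJob") \
    .config("spark.serializer",
            "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.kryo.registrationRequired", "true") \
    .config("spark.kryo.classesToRegister",
            "com.example.OrderEvent,com.example.CustomerRecord") \
    .getOrCreate()
```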
Cluster Configuration & Resource Tuning
Finally, proper resource configuration ensures your optimized code runs on a well-tuned cluster:
spark = SparkSession.builder \
.appName("ProductionWorkload") \
.config("spark.executor.memory", "8g") \
.config("spark.executor.cores", "4") \
.config("spark.executor.memoryOverhead", "1g") \
.config("spark.sql.shuffle.partitions", "400") \
.config("spark.sql.adaptive.enabled", "true") \
.config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
.config("spark.dynamicAllocation.enabled", "true") \
.config("spark.dynamicAllocation.minExecutors", "2") \
.config("spark.dynamicAllocation.maxExecutors", "50") \
.getOrCreate()
Adaptive Query Execution (AQE), available since Spark 3.0 and enabled by default since Spark 3.2, dynamically adjusts shuffle partitions and join strategies based on runtime statistics. It handles many optimization problems automatically, but understanding the fundamentals remains essential for complex workloads.