Apache Spark - Optimize Joins (Broadcast, Sort-Merge, Shuffle Hash)
Key Insights
- Broadcast joins eliminate shuffle overhead entirely and can speed up joins by 10-100x when one table is small enough to fit in executor memory
- Sort-merge join is Spark’s default for large tables, but pre-bucketing your data on join keys can skip the expensive sort phase entirely
- Adaptive Query Execution (AQE) in Spark 3.0+ automatically handles many join optimizations, including skew mitigation, but understanding manual strategies remains essential for predictable performance
Introduction to Join Performance in Spark
Joins are the most expensive operations in distributed data processing. When you join two DataFrames in Spark, the framework must ensure matching keys end up on the same executor. This typically means shuffling data across the network—moving potentially terabytes of data between nodes.
The cost is staggering. A naive join between two 100GB tables might shuffle 200GB across your cluster, saturating network bandwidth and causing massive disk I/O as Spark spills intermediate data. Meanwhile, a well-optimized join on the same data might shuffle nothing at all.
Spark supports three primary join strategies: broadcast hash join, sort-merge join, and shuffle hash join. Each has distinct performance characteristics, and choosing correctly can mean the difference between a 10-minute job and a 10-hour one.
Broadcast Hash Join
Broadcast hash join is the fastest strategy when one side of the join is small. Spark serializes the smaller DataFrame and sends a complete copy to every executor. Each executor then builds an in-memory hash table and probes it against the local partitions of the larger table. No shuffle required.
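The build-and-probe pattern each executor runs can be sketched in plain Python (a conceptual stand-in, not Spark API; the row dicts and column names are illustrative):

```python
# Sketch of the per-executor work in a broadcast hash join (pure Python,
# not Spark API): build a hash table from the broadcast side once, then
# probe it with every local row of the large side. No shuffle, no sort.
def broadcast_hash_join(large_partition, broadcast_rows, key):
    # Build phase: index the small (broadcast) table by join key.
    build = {}
    for row in broadcast_rows:
        build.setdefault(row[key], []).append(row)
    # Probe phase: stream the large partition, emitting matches.
    for row in large_partition:
        for match in build.get(row[key], []):
            yield {**row, **match}

orders = [{"order_id": 1, "country_code": "US"},
          {"order_id": 2, "country_code": "FR"}]
countries = [{"country_code": "US", "country_name": "United States"}]

joined = list(broadcast_hash_join(orders, countries, "country_code"))
# Only order 1 has a match in the broadcast table
```

The build table is constructed once per executor but probed by every task, which is why this wins whenever the small side fits in memory.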
By default, Spark broadcasts tables smaller than 10MB:
// Check current threshold (default: 10MB)
spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
// Increase threshold to 100MB
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "104857600")
// Disable auto-broadcast entirely
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
When you know a table is small enough, use an explicit broadcast hint:
import org.apache.spark.sql.functions.broadcast
val largeOrders = spark.table("orders") // 500GB
val smallCountries = spark.table("countries") // 5MB
// Force broadcast regardless of threshold
val result = largeOrders.join(
  broadcast(smallCountries),
  Seq("country_code")
)
In PySpark:
from pyspark.sql.functions import broadcast
result = large_orders.join(
    broadcast(small_countries),
    on="country_code"
)
Verify Spark chose broadcast join by examining the physical plan:
result.explain(true)
Look for BroadcastHashJoin in the output:
== Physical Plan ==
*(2) BroadcastHashJoin [country_code#10], [country_code#25], Inner, BuildRight
:- *(2) Filter isnotnull(country_code#10)
: +- *(2) ColumnarToRow
: +- FileScan parquet [order_id#9,country_code#10]
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]))
+- *(1) Filter isnotnull(country_code#25)
+- *(1) ColumnarToRow
+- FileScan parquet [country_code#25,country_name#26]
Warning: Broadcasting tables that are too large causes executor OOM errors. The broadcast table must fit in memory on every executor, plus overhead for the hash table structure. A 500MB CSV might expand to 2GB in memory. Be conservative with thresholds.
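The expansion from on-disk text to in-memory objects can be illustrated in plain Python (illustrative numbers, not a Spark measurement; JVM object overhead differs in the details, but the effect is the same):

```python
# Rough illustration (pure Python, not Spark) of why parsed rows occupy
# far more memory than their on-disk text: per-object overhead on every
# field, plus the container holding them.
import sys

line = "12345,US,2024-01-15,99.99"   # ~25 bytes on disk
row = line.split(",")                # parsed in-memory representation
raw_size = len(line.encode())
parsed_size = sys.getsizeof(row) + sum(sys.getsizeof(f) for f in row)
print(raw_size, parsed_size)         # parsed form is several times larger
```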
Sort-Merge Join
Sort-merge join is Spark’s default strategy for joining large tables. Both sides are shuffled by join key, sorted within each partition, then merged using a streaming algorithm that requires minimal memory.
The process has three phases:
- Shuffle: Repartition both tables so matching keys land on the same executor
- Sort: Sort each partition by the join key
- Merge: Walk through both sorted partitions simultaneously, matching keys
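The merge phase above can be sketched in plain Python (a conceptual model, not Spark API; note the constant memory use regardless of input size):

```python
# Sketch of the merge phase of a sort-merge join (pure Python, not Spark
# API): both inputs are already sorted by key, so two cursors advance in
# lockstep and memory stays constant no matter how large the inputs are.
def merge_join(left, right, key):
    i = j = 0
    while i < len(left) and j < len(right):
        lk, rk = left[i][key], right[j][key]
        if lk < rk:
            i += 1
        elif lk > rk:
            j += 1
        else:
            # Emit all right-side matches for the current left row.
            j0 = j
            while j < len(right) and right[j][key] == lk:
                yield {**left[i], **right[j]}
                j += 1
            i += 1
            # Same key again on the left: rewind the right cursor.
            if i < len(left) and left[i][key] == lk:
                j = j0

orders = [{"customer_id": 1, "order": "a"},
          {"customer_id": 1, "order": "b"},
          {"customer_id": 3, "order": "c"}]
customers = [{"customer_id": 1, "name": "Ann"},
             {"customer_id": 2, "name": "Bo"}]
matched = list(merge_join(orders, customers, "customer_id"))
# Both customer-1 orders match Ann; customers 2 and 3 have no partner
```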
// Force sort-merge join by disabling broadcast
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
val orders = spark.table("orders") // 500GB
val customers = spark.table("customers") // 200GB
val result = orders.join(customers, Seq("customer_id"))
result.explain()
The plan shows SortMergeJoin:
== Physical Plan ==
*(5) SortMergeJoin [customer_id#10], [customer_id#50], Inner
:- *(2) Sort [customer_id#10 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(customer_id#10, 200)
: +- *(1) Filter isnotnull(customer_id#10)
+- *(4) Sort [customer_id#50 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(customer_id#50, 200)
+- *(3) Filter isnotnull(customer_id#50)
Notice the Exchange (shuffle) and Sort operations. You can eliminate both by pre-bucketing tables:
// Create bucketed tables (do this once during ETL)
orders.write
  .bucketBy(200, "customer_id")
  .sortBy("customer_id")
  .saveAsTable("orders_bucketed")
customers.write
  .bucketBy(200, "customer_id")
  .sortBy("customer_id")
  .saveAsTable("customers_bucketed")
// Join bucketed tables - no shuffle or sort needed
val ordersBucketed = spark.table("orders_bucketed")
val customersBucketed = spark.table("customers_bucketed")
val result = ordersBucketed.join(customersBucketed, Seq("customer_id"))
The plan now shows no Exchange or Sort—just the merge:
== Physical Plan ==
*(1) SortMergeJoin [customer_id#10], [customer_id#50], Inner
:- *(1) Filter isnotnull(customer_id#10)
: +- FileScan parquet [customer_id#10] Batched: true, Bucketed: true
+- *(1) Filter isnotnull(customer_id#50)
+- FileScan parquet [customer_id#50] Batched: true, Bucketed: true
Bucketing requires both tables to be bucketed on the join key with the same number of buckets; newer Spark versions can also coalesce mismatched bucket counts when one is a multiple of the other (spark.sql.bucketing.coalesceBucketsInJoin.enabled).
Shuffle Hash Join
Shuffle hash join is a middle ground: it shuffles data like sort-merge but builds hash tables instead of sorting. This works well when the smaller side (after shuffling) fits in memory per partition.
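The shuffle-then-hash pattern can be sketched in plain Python (a conceptual model, not Spark API; `n` stands in for the shuffle partition count):

```python
# Sketch of a shuffle hash join (pure Python, not Spark API): both sides
# are hash-partitioned on the key (the "shuffle"), then each partition of
# the smaller side is hashed and probed locally -- no sorting anywhere.
def hash_partition(rows, key, n):
    parts = [[] for _ in range(n)]
    for row in rows:
        parts[hash(row[key]) % n].append(row)
    return parts

def shuffle_hash_join(big, small, key, n=4):
    out = []
    for big_part, small_part in zip(hash_partition(big, key, n),
                                    hash_partition(small, key, n)):
        # Build a hash table only for this partition of the smaller side,
        # so only one partition's worth of data must fit in memory.
        build = {}
        for row in small_part:
            build.setdefault(row[key], []).append(row)
        for row in big_part:
            for match in build.get(row[key], []):
                out.append({**row, **match})
    return out

big = [{"k": 1, "v": "x"}, {"k": 2, "v": "y"}, {"k": 3, "v": "z"}]
small = [{"k": 2, "w": "m"}]
print(shuffle_hash_join(big, small, "k"))
```

Because matching keys land in the same partition index on both sides, each partition pair can be joined independently.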
Spark doesn’t use shuffle hash join by default. Enable it explicitly:
// Prefer shuffle hash join over sort-merge when the planner has a choice
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")
// Allow shuffle hash when the build side is at least 3x smaller (Spark 3.2+)
spark.conf.set("spark.sql.shuffledHashJoinFactor", "3")
Shuffle hash join outperforms sort-merge when:
- One table is significantly smaller than the other (but too large to broadcast)
- The smaller table’s partitions fit in executor memory
- Sort overhead exceeds hash table construction cost
# Compare wall-clock time of the two strategies (each write forces execution)
import time
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")
start = time.time()
medium_table.join(large_table, "key").write.mode("overwrite").parquet("/tmp/sm_result")
print(f"sort-merge: {time.time() - start:.1f}s")
spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")
start = time.time()
medium_table.join(large_table, "key").write.mode("overwrite").parquet("/tmp/sh_result")
print(f"shuffle hash: {time.time() - start:.1f}s")
Choosing the Right Join Strategy
| Strategy | Best When | Shuffle | Memory Requirement | Handles Skew |
|---|---|---|---|---|
| Broadcast Hash | One table < 100MB | None | Full table per executor | Yes |
| Sort-Merge | Both tables large | Both sides | Low (streaming) | Poorly |
| Shuffle Hash | One side 100MB-1GB post-shuffle | Both sides | Partition fits in memory | Poorly |
Decision framework:
- Can you broadcast? If either table is under 100MB (or fits in executor memory), broadcast it.
- Are tables pre-bucketed? Use sort-merge with bucketed tables for zero-shuffle joins.
- Is one side much smaller? Try shuffle hash join with appropriate memory settings.
- Both sides huge? Sort-merge with sufficient partitions (aim for 100-200MB per partition).
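The last rule of thumb is simple arithmetic; here is a sketch (the 150MB target and table sizes are illustrative assumptions, not recommendations for your cluster):

```python
# Back-of-the-envelope sizing for the "100-200MB per shuffle partition"
# rule of thumb. Rounds up so no partition exceeds the target.
def shuffle_partitions(total_shuffle_bytes, target_partition_bytes=150 * 1024**2):
    return -(-total_shuffle_bytes // target_partition_bytes)  # ceiling division

# A 500GB + 200GB sort-merge join shuffles both sides:
total = (500 + 200) * 1024**3
n = shuffle_partitions(total)
print(n)  # 4779 partitions; set via spark.sql.shuffle.partitions
```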
Handling Data Skew in Joins
Data skew occurs when certain join keys appear far more frequently than others. A single executor gets stuck processing millions of rows while others sit idle.
Key salting distributes hot keys across multiple partitions:
from pyspark.sql.functions import col, lit, concat, rand, floor, explode, array, when
# Identify skewed keys (keys with >1M rows)
skewed_keys = ["US", "CN", "IN"]
salt_factor = 10
# Salt the large table's skewed keys with a random suffix 0..salt_factor-1
large_df_salted = large_df.withColumn(
    "salted_key",
    when(
        col("country_code").isin(skewed_keys),
        concat(col("country_code"), lit("_"), floor(rand() * salt_factor))
    ).otherwise(col("country_code"))
)
# Explode each skewed row of the small table into one row per salt value
# for its OWN key only, so salts never cross keys
small_df_exploded = small_df.withColumn(
    "salt_values",
    when(
        col("country_code").isin(skewed_keys),
        array([concat(col("country_code"), lit(f"_{i}")) for i in range(salt_factor)])
    ).otherwise(array(col("country_code")))
).select("*", explode("salt_values").alias("salted_key")).drop("salt_values")
# Join on salted key
result = large_df_salted.join(small_df_exploded, "salted_key")
# Join on salted key
result = large_df_salted.join(small_df_exploded, "salted_key")
Adaptive Query Execution (AQE) handles skew automatically in Spark 3.0+:
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
AQE detects skewed partitions at runtime and splits them automatically.
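The skew test those two configs control can be modeled roughly as follows (a deliberate simplification of Spark's runtime logic, pure Python rather than Spark internals):

```python
# Simplified model of AQE's skew detection: a partition counts as skewed
# if it is BOTH larger than factor x the median partition size AND above
# the absolute byte threshold. AQE then splits such partitions.
import statistics

def skewed_partitions(sizes, factor=5, threshold=256 * 1024**2):
    median = statistics.median(sizes)
    return [i for i, s in enumerate(sizes)
            if s > factor * median and s > threshold]

mb = 1024**2
sizes = [100 * mb, 120 * mb, 110 * mb, 900 * mb]  # one hot partition
print(skewed_partitions(sizes))  # partition 3 is split
```

Requiring both conditions prevents splitting partitions that are relatively large but cheap in absolute terms.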
Monitoring and Debugging Join Performance
Always check the physical plan before running expensive joins:
// Detailed plan with statistics
df.explain(mode = "extended")
// Formatted plan (Spark 3.0+)
df.explain(mode = "formatted")
Key things to look for:
- Join type: BroadcastHashJoin, SortMergeJoin, or ShuffledHashJoin
- Exchange nodes: each Exchange is a shuffle
- Sort nodes: Expensive for large datasets
- Statistics: Row counts and data sizes
In Spark UI, examine:
- Stage timeline: Long-running tasks indicate skew
- Shuffle read/write: High values suggest broadcast opportunities
- Task duration distribution: Skew shows as outlier tasks
# Trigger execution, then inspect metrics in the Spark UI
result = df1.join(df2, "key")
result.count()  # materializes the join
# Check Spark UI at http://driver:4040/SQL/ for shuffle sizes and join strategy
The difference between a well-optimized join and a naive one is often 10-100x in execution time. Invest the time to understand your data sizes, check explain plans, and choose strategies deliberately.