Apache Spark - Whole Stage Code Generation
Key Insights
• Whole-stage code generation (WSCG) compiles entire query stages into single optimized functions, eliminating per-operator virtual function calls and improving CPU efficiency, often by 2-10x compared to the Volcano iterator model
• Spark automatically enables WSCG for supported operators (filters, projections, simple aggregations) and falls back to row-by-row iterator execution for unsupported ones; shuffles, such as the one a global sort requires, split a query into separate generated stages, and row-based external sources such as JDBC gain less from it
• Understanding when WSCG activates and how to structure queries for optimal code generation is critical for performance tuning large-scale data pipelines
Understanding the Volcano Model Limitation
Traditional query engines use the Volcano iterator model, where each operator implements a next() method that returns one row at a time. This creates significant overhead through virtual function calls, poor instruction pipelining, and inefficient CPU cache usage.
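// Simplified Volcano model: each operator pulls one row at a time from its child
type Row = IndexedSeq[Any]
trait Operator {
  def next(): Row // returns null once the input is exhausted
}
class FilterOperator(child: Operator, predicate: Row => Boolean) extends Operator {
  def next(): Row = {
    var row = child.next()
    // Keep pulling until a row passes the predicate or the input runs out
    while (row != null && !predicate(row)) {
      row = child.next()
    }
    row
  }
}
class ProjectOperator(child: Operator, columns: Seq[Int]) extends Operator {
  def next(): Row = {
    val row = child.next()
    // Keep only the requested column positions
    if (row != null) columns.map(row).toIndexedSeq else null
  }
}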
For a query with multiple operators, each row traverses the entire operator tree, incurring function call overhead at every step. With billions of rows, this becomes a critical bottleneck.
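The per-row cost is easy to count. Below is a minimal runnable sketch in plain Scala, without Spark, that wires a scan, a filter, and a projection as Volcano-style operators and counts every next() call. The object VolcanoDemo, its operator classes, and the nextCalls counter are hypothetical names for illustration; rows are plain Long values, and Option replaces null as the end-of-input signal.

```scala
// Toy Volcano pipeline in plain Scala (hypothetical names, no Spark).
// Rows are Long values; None signals end of input instead of null.
object VolcanoDemo {
  var nextCalls = 0L // total next() invocations across all operators

  trait Operator { def next(): Option[Long] }

  final class Scan(n: Long) extends Operator {
    private var i = 0L
    def next(): Option[Long] = {
      nextCalls += 1
      if (i < n) { i += 1; Some(i - 1) } else None
    }
  }

  final class Filter(child: Operator, p: Long => Boolean) extends Operator {
    def next(): Option[Long] = {
      nextCalls += 1
      var row = child.next()
      while (row.exists(r => !p(r))) row = child.next() // skip rows failing p
      row
    }
  }

  final class Project(child: Operator, f: Long => Long) extends Operator {
    def next(): Option[Long] = { nextCalls += 1; child.next().map(f) }
  }

  // Builds Project(Filter(Scan)) and drains it; returns the number of output rows
  def run(n: Long): Long = {
    nextCalls = 0L
    val plan = new Project(new Filter(new Scan(n), _ % 2 == 0), _ * 2)
    var produced = 0L
    while (plan.next().isDefined) produced += 1
    produced
  }
}
```

With ten input rows, VolcanoDemo.run(10L) returns 5 and leaves nextCalls at 23: eleven calls on the scan (ten rows plus the final empty result), six on the filter, and six on the projection, all to move five rows to the top of the tree.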
How Whole-Stage Code Generation Works
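WSCG fuses multiple operators into a single code-generated function. Instead of each operator pulling rows from its child through next(), the generated code runs one loop over the stage's input and pushes each row through the fused operator logic, keeping intermediate values in local variables rather than building a row object between operators. When the input is a vectorized Parquet or ORC scan, that loop walks the rows of each columnar batch.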
// Example query
import spark.implicits._ // enables the $"column" syntax (already imported in spark-shell)
val df = spark.read.parquet("users")
.filter($"age" > 25)
.filter($"country" === "US")
.select($"name", $"age" * 2)
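// Generated Java code (simplified; Spark's real output uses numbered variable names
// and an UnsafeRowWriter). Assumes the columns are name = 0, age = 1, country = 2.
final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
  private scala.collection.Iterator scan_input;
  private final UTF8String literal_US = UTF8String.fromString("US");
  private final GenericInternalRow project_row = new GenericInternalRow(2);
  public void init(int index, scala.collection.Iterator[] inputs) {
    scan_input = inputs[0];
  }
  protected void processNext() throws java.io.IOException {
    while (scan_input.hasNext()) {
      InternalRow scan_row = (InternalRow) scan_input.next();
      // Fused filters: age > 25 AND country = 'US' (null values never pass)
      if (scan_row.isNullAt(1) || scan_row.getInt(1) <= 25) continue;
      if (scan_row.isNullAt(2) || !scan_row.getUTF8String(2).equals(literal_US)) continue;
      // Fused projection: values stay in locals, no intermediate row per operator
      int age = scan_row.getInt(1);
      project_row.update(0, scan_row.isNullAt(0) ? null : scan_row.getUTF8String(0));
      project_row.setInt(1, age * 2);
      // Emit result, then yield to the caller once the output buffer holds a row
      append(project_row);
      if (shouldStop()) return;
    }
  }
}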
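The generated code eliminates the per-operator virtual calls, keeps intermediate values in local variables (often CPU registers) instead of intermediate row objects, and hands the JIT compiler one tight loop it can optimize as a whole, which also improves CPU cache locality.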
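Spark assembles this loop with a produce/consume protocol on its CodegenSupport trait: produce() recurses down to the leaf, which opens the loop, and each operator's consume step emits the per-row code for its parent. The toy sketch below imitates that nesting by passing the parent's per-row code down as a string. ProduceConsumeDemo and its case classes are hypothetical, and the output is Java-like text for illustration, not code Spark would compile.

```scala
// Toy model of produce/consume code generation (hypothetical names, not Spark's API).
// Each operator contributes a fragment; the fragments nest into one loop.
object ProduceConsumeDemo {
  sealed trait Op
  final case class Scan(input: String) extends Op
  final case class Filter(cond: String, child: Op) extends Op
  final case class Project(exprs: Seq[String], child: Op) extends Op

  // `perRow` is the code the parent wants to run for each row that reaches it
  def produce(op: Op, perRow: String): String = op match {
    case Scan(input) =>
      // The leaf opens the single loop and splices in all downstream per-row code
      s"while ($input.hasNext()) {\n  row = $input.next();\n$perRow}\n"
    case Filter(cond, child) =>
      // A filter skips the rest of the loop body when its predicate fails
      produce(child, s"  if (!($cond)) continue;\n$perRow")
    case Project(exprs, child) =>
      // A projection computes its outputs into local variables before the parent's code
      val assigns = exprs.zipWithIndex.map { case (e, i) => s"  out$i = $e;\n" }.mkString
      produce(child, assigns + perRow)
  }

  // The root of the stage appends each surviving row to the output buffer
  def wholeStage(plan: Op): String = produce(plan, "  append(out);\n")
}
```

For the example query, wholeStage(Project(Seq("name", "age * 2"), Filter("country.equals(US)", Filter("age > 25", Scan("input"))))) yields one while loop containing the age check, the country check, the two projected assignments, and append(out), in that order, mirroring the shape of the simplified generated code above.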
Identifying WSCG in Query Plans
Use explain() to verify WSCG activation; in Spark 3.0 and later, explain("codegen") also prints the generated code for each stage:
val df = spark.read.parquet("transactions")
.filter($"amount" > 1000)
.groupBy($"customer_id")
.agg(sum($"amount").as("total"))
df.explain("codegen")
Look for the *(n) markers in the physical plan, which identify whole-stage-codegen operators:
*(2) HashAggregate(keys=[customer_id#10], functions=[sum(amount#11)])
+- Exchange hashpartitioning(customer_id#10, 200)
   +- *(1) HashAggregate(keys=[customer_id#10], functions=[partial_sum(amount#11)])
      +- *(1) Project [customer_id#10, amount#11]
         +- *(1) Filter (isnotnull(amount#11) AND (amount#11 > 1000.0))
            +- *(1) FileScan parquet [customer_id#10,amount#11]
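The *(1) and *(2) prefixes mark operators belonging to code-generated stages 1 and 2; all operators within the same stage execute in a single generated function. Exchange has no prefix because the shuffle runs outside generated code, so it marks the boundary between the two stages.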
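When comparing many plans, it can help to group operators by stage id programmatically. Here is a small plain-Scala sketch, assuming the explain() text format shown above, where fused operators carry a *(n) prefix. CodegenStages and its methods are hypothetical helpers, and the plan text format may differ across Spark versions (for example, with adaptive query execution enabled).

```scala
// Groups operator names in a physical-plan string by whole-stage codegen id.
// Assumes fused operators are printed as "*(n) OperatorName ..." (format may vary).
object CodegenStages {
  private val Fused = """\*\((\d+)\)\s+(\w+)""".r
  private val AnyOp = """^[\s|:+\-]*(\w+)""".r

  def byStage(plan: String): Map[Int, List[String]] =
    plan.split("\n").toList
      .flatMap(line => Fused.findFirstMatchIn(line))
      .map(m => m.group(1).toInt -> m.group(2))
      .groupBy(_._1)
      .map { case (id, ops) => id -> ops.map(_._2) }

  // Operators without a "*(n)" prefix, such as Exchange, run outside generated code
  def unfused(plan: String): List[String] =
    plan.split("\n").toList
      .filterNot(_.contains("*("))
      .flatMap(line => AnyOp.findFirstMatchIn(line).map(_.group(1)))
}
```

Applied to the plan above, byStage maps stage 1 to HashAggregate, Project, Filter, and FileScan and stage 2 to the final HashAggregate, while unfused returns only Exchange.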
Operations Supporting WSCG
WSCG only fuses operators that implement Spark's CodegenSupport trait:
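// Supported operations: these fuse into a single generated stage
val supported = df
  .filter($"status" === "active")              // Filter
  .select($"id", ($"value" * 1.1).as("value"))  // Project (alias keeps "value" available)
  .where($"value" > 100)                        // Additional filter
  .withColumn("computed", $"value" / 10)        // Column expression
// Operations that end a stage or run outside generated code
val unsupported = df
  .sort($"timestamp".desc) // Global sort adds an Exchange (shuffle), which ends the stage
  .limit(100)              // Sort + limit is typically planned as TakeOrderedAndProject, which has no codegen
// Row-based external sources such as JDBC produce rows rather than columnar
// batches, so they benefit less from code generation
val jdbc = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://...")
  .option("dbtable", "...") // JDBC reads also require a dbtable or query option
  .load()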
Common WSCG-compatible operations:
- Filters and predicates
- Projections and column expressions
- Simple aggregations (sum, count, avg)
- Hash joins (the streamed side is fused into the stage; the build side is prepared separately, e.g. as a broadcast)
- Columnar data sources (Parquet, ORC)
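Where stage boundaries fall can be approximated with a simple rule: walking a chain of physical operators from the scan upward, consecutive operators that support code generation share one stage, and any operator that doesn't runs on its own. This is a simplified, hypothetical model loosely based on Spark's CollapseCodegenStages rule; it handles only linear plans (no joins), and the supportsCodegen flags used with it are illustrative.

```scala
// Simplified model of stage formation (hypothetical, loosely after Spark's
// CollapseCodegenStages rule): linear operator chains only, listed leaf first.
object StageSplitter {
  final case class PhysOp(name: String, supportsCodegen: Boolean)

  // Right(names) = one fused codegen stage; Left(name) = operator run outside codegen
  type Stage = Either[String, List[String]]

  def split(chain: Seq[PhysOp]): List[Stage] =
    chain.foldLeft(List.empty[Stage]) {
      // Extend the current fused stage if this operator can join it
      case (Right(stage) :: rest, op) if op.supportsCodegen => Right(stage :+ op.name) :: rest
      // Otherwise start a new fused stage...
      case (acc, op) if op.supportsCodegen => Right(List(op.name)) :: acc
      // ...or record an operator that breaks the pipeline
      case (acc, op) => Left(op.name) :: acc
    }.reverse
}
```

For a chain of FileScan, Filter, Project, Exchange, HashAggregate, and TakeOrderedAndProject, with only Exchange and TakeOrderedAndProject marked as lacking codegen, split returns one fused stage of FileScan, Filter, and Project, then Exchange, then a fused stage holding HashAggregate, then TakeOrderedAndProject.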
Performance Comparison
Here’s a benchmark comparing WSCG-enabled versus disabled execution:
import org.apache.spark.sql.internal.SQLConf
// Wall-clock timing helper returning milliseconds
def timeMs[A](block: => A): Double = {
  val start = System.nanoTime()
  block
  (System.nanoTime() - start) / 1e6
}
// Disable WSCG
spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false")
val disabledTime = timeMs {
  spark.range(1000000000)
    .filter($"id" % 2 === 0)
    .filter($"id" % 3 === 0)
    .selectExpr("id", "id * 2 as doubled")
    .count()
}
// Enable WSCG
spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
val enabledTime = timeMs {
  spark.range(1000000000)
    .filter($"id" % 2 === 0)
    .filter($"id" % 3 === 0)
    .selectExpr("id", "id * 2 as doubled")
    .count()
}
println(f"Speedup: ${disabledTime / enabledTime}%.1fx")
// Typical output: Speedup: 3-5x (varies with hardware and Spark version;
// run each query once beforehand so JIT warm-up doesn't skew the first timing)
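The same effect can be reproduced outside Spark. This plain-Scala sketch, using a hypothetical FusionBench object, computes the benchmark's query (ids divisible by both 2 and 3, projected as id * 2) once as a chain of Scala iterators, where each stage calls its child's next(), and once as a single hand-fused loop. Timings from a naive harness like this are noisy; a dedicated harness such as JMH gives more trustworthy numbers.

```scala
// Plain-Scala comparison (no Spark): iterator chain vs. one fused loop.
// Both return (count, checksum of id * 2) over ids in [0, n) divisible by 2 and 3.
object FusionBench {
  // Volcano style: each stage is an Iterator that pulls from the previous one
  def volcano(n: Long): (Long, Long) = {
    val it = (0L until n).iterator
      .filter(_ % 2 == 0)
      .filter(_ % 3 == 0)
      .map(_ * 2)
    var count = 0L
    var checksum = 0L
    while (it.hasNext) { checksum += it.next(); count += 1 }
    (count, checksum)
  }

  // Fused style: predicates and projection inlined into one loop over local variables
  def fused(n: Long): (Long, Long) = {
    var id = 0L
    var count = 0L
    var checksum = 0L
    while (id < n) {
      if (id % 2 == 0 && id % 3 == 0) {
        checksum += id * 2
        count += 1
      }
      id += 1
    }
    (count, checksum)
  }

  // Returns the block's result and its wall-clock time in milliseconds
  def timeMs[A](block: => A): (A, Double) = {
    val start = System.nanoTime()
    val result = block
    (result, (System.nanoTime() - start) / 1e6)
  }
}
```

Calling FusionBench.timeMs(FusionBench.volcano(100000000L)) and FusionBench.timeMs(FusionBench.fused(100000000L)) a few times each lets the JIT warm up before comparing; both variants return the same (count, checksum) pair, so only the execution strategy differs.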
Optimizing for Code Generation
Structure queries to maximize WSCG effectiveness:
// Poor: filter written after a global sort
val inefficient = df
  .filter($"status" === "active")
  .orderBy($"timestamp") // Global sort adds an Exchange, which breaks the stage
  .filter($"amount" > 100)
  .select($"id", $"amount")
// Better: group filters before the sort. Catalyst usually pushes deterministic
// filters below a sort by itself, so check explain() before rewriting by hand
val efficient = df
  .filter($"status" === "active")
  .filter($"amount" > 100)
  .select($"id", $"amount")
  .orderBy($"timestamp") // Single stage break at the end
// Best: eliminate the sort entirely when downstream code doesn't need ordered output
val optimal = df
  .filter($"status" === "active" && $"amount" > 100)
  .select($"id", $"amount")
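Whether filter order matters by hand depends on how Catalyst rewrites plans: its optimizer pushes deterministic filters below operators such as Sort and merges adjacent filters into one (CombineFilters, together with Spark's predicate-pushdown rules). The toy optimizer below, with hypothetical names and string predicates, applies just those two rewrites until the plan stops changing. It is a simplified model; real Spark may also push filters into the scan itself.

```scala
// Toy logical optimizer (hypothetical names): pushes filters below sorts and
// merges adjacent filters, repeating until the plan reaches a fixed point.
object ToyOptimizer {
  sealed trait Plan
  final case class Scan(table: String) extends Plan
  final case class Filter(cond: String, child: Plan) extends Plan
  final case class Sort(key: String, child: Plan) extends Plan

  private def rewrite(p: Plan): Plan = p match {
    // A deterministic filter can run before a sort: sorting doesn't change which rows pass
    case Filter(c, Sort(k, child))     => Sort(k, Filter(c, rewrite(child)))
    // Two stacked filters become one predicate evaluated in a single step
    case Filter(c1, Filter(c2, child)) => Filter(s"($c2) AND ($c1)", rewrite(child))
    case Filter(c, child)              => Filter(c, rewrite(child))
    case Sort(k, child)                => Sort(k, rewrite(child))
    case s: Scan                       => s
  }

  @annotation.tailrec
  def optimize(p: Plan): Plan = {
    val next = rewrite(p)
    if (next == p) p else optimize(next)
  }
}
```

Feeding it the "Poor" shape, Filter over Sort over Filter over Scan, and the "Better" shape, Sort over two Filters over Scan, both converge to a single Sort over one combined Filter over the Scan, which is why reordering filters by hand often leaves the physical plan unchanged.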
For complex expressions, ensure they’re code-gen friendly:
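import org.apache.spark.sql.functions.{udf, when}
// Avoid Scala UDFs where a built-in exists: generated code calls them as opaque
// functions, so Catalyst cannot optimize inside them, and arguments and results are
// converted between Spark's internal format and Scala objects. Python UDFs run in a
// separate operator outside the fused stage entirely.
val timesTwo = udf((x: Int) => x * 2)
val withUDF = df.withColumn("result", timesTwo($"value"))
// Use built-in expressions instead
val withExpr = df.withColumn("result", $"value" * 2)
// Complex logic using built-in conditional expressions
val complex = df.withColumn("category",
  when($"amount" < 100, "small")
    .when($"amount" < 1000, "medium")
    .otherwise("large")
)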
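The difference between a built-in expression and a UDF can be illustrated with a toy expression generator: built-in expressions are trees an optimizer can inspect, constant-fold, and inline as Java source, while a UDF is an opaque function object the generated code can only call. ExprCodegen and its case classes are hypothetical, and the emitted strings only approximate the shape of Spark's generated code.

```scala
// Toy expression codegen (hypothetical names, not Spark's API).
object ExprCodegen {
  sealed trait Expr
  final case class Col(name: String) extends Expr
  final case class Lit(value: Int) extends Expr
  final case class Mul(left: Expr, right: Expr) extends Expr
  final case class Udf(index: Int, arg: Expr) extends Expr // body lives in an opaque function

  // Constant folding works because the optimizer knows what Mul means
  def fold(e: Expr): Expr = e match {
    case Mul(l, r) =>
      (fold(l), fold(r)) match {
        case (Lit(a), Lit(b)) => Lit(a * b)
        case (fl, fr)         => Mul(fl, fr)
      }
    case Udf(i, arg) => Udf(i, fold(arg)) // the argument can fold; the UDF body cannot
    case other       => other
  }

  // Built-ins inline as Java arithmetic; a UDF becomes a boxed call through a function reference
  def genCode(e: Expr): String = e match {
    case Col(n)      => n
    case Lit(v)      => v.toString
    case Mul(l, r)   => s"(${genCode(l)} * ${genCode(r)})"
    case Udf(i, arg) => s"((Integer) udfs[$i].apply(${genCode(arg)}))"
  }
}
```

Here genCode(fold(Mul(Col("value"), Mul(Lit(1), Lit(2))))) produces (value * 2), while a Udf node always becomes a call through udfs[i].apply, no matter how simple the function behind it is.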
Debugging Generated Code
Access generated code for inspection:
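// spark.sql.codegen.comments is read when the session starts, so pass it at launch,
// e.g. spark-shell --conf spark.sql.codegen.comments=true
spark.conf.set("spark.sql.codegen.wholeStage", "true") // WSCG on (the default)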
val df = spark.range(100)
.filter($"id" % 2 === 0)
.selectExpr("id * 2 as result")
// View generated code
df.queryExecution.debug.codegen()
This outputs the actual Java code Spark generates, useful for understanding performance characteristics and identifying optimization opportunities.
Monitoring Code Generation Metrics
Track WSCG effectiveness through Spark UI and metrics:
import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.util.QueryExecutionListener
// AdaptiveSparkPlanHelper.collect (Spark 3.0+) also descends into adaptive (AQE)
// plans, which a plain executedPlan.collect would not traverse
val listener = new QueryExecutionListener with AdaptiveSparkPlanHelper {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    val codegenStages = collect(qe.executedPlan) { case w: WholeStageCodegenExec => w }
    println(s"Query used ${codegenStages.size} codegen stages")
  }
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
}
spark.listenerManager.register(listener)
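Monitor the SQL tab of the Spark UI, where each WholeStageCodegen node reports its duration; compilation time and generated class and method sizes are exposed through Spark's CodeGenerator metrics source. Watch generated method size in particular: HotSpot does not JIT-compile methods larger than 8,000 bytes of bytecode, so such methods stay interpreted and slow, and if a method exceeds spark.sql.codegen.hugeMethodLimit (65,535 bytes by default, the JVM's maximum method size), Spark deactivates whole-stage codegen for that subtree and falls back to non-fused execution.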
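Generated method sizes can be checked against these thresholds with a small helper. This sketch assumes the explain("codegen") subtree headers include a maxMethodCodeSize:<n> field, as Spark 3.x output does; verify the format on your version. The MethodSizeCheck name is hypothetical, and 8,000 bytes is HotSpot's default HugeMethodLimit.

```scala
// Classifies generated-method bytecode sizes (hypothetical helper).
// Assumes explain("codegen") subtree headers contain "maxMethodCodeSize:<n>".
object MethodSizeCheck {
  val HotSpotHugeMethodLimit = 8000       // HotSpot skips JIT compilation above this size
  val DefaultSparkHugeMethodLimit = 65535 // default spark.sql.codegen.hugeMethodLimit

  private val SizePattern = """maxMethodCodeSize:(\d+)""".r

  // One entry per generated subtree, in the order they appear in the dump
  def methodSizes(codegenDump: String): List[Int] =
    SizePattern.findAllMatchIn(codegenDump).map(_.group(1).toInt).toList

  def classify(size: Int, sparkLimit: Int = DefaultSparkHugeMethodLimit): String =
    if (size > sparkLimit) "fallback"                     // Spark deactivates WSCG for the subtree
    else if (size > HotSpotHugeMethodLimit) "interpreted" // valid bytecode, but never JIT-compiled
    else "ok"
}
```

Paste the text printed by explain("codegen") into methodSizes to get one size per subtree, then pass each to classify, supplying sparkLimit if spark.sql.codegen.hugeMethodLimit has been changed from its default.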
Whole-stage code generation represents a fundamental shift in query execution architecture. By understanding its mechanics and constraints, you can design Spark jobs that leverage hardware efficiency while avoiding common pitfalls that force expensive fallback paths.