Apache Spark - Stages and Tasks Explained
Key Insights
- Spark breaks jobs into stages at shuffle boundaries (wide transformations), with each stage containing tasks that execute narrow transformations in parallel across partitions
- Each task processes one partition of data and runs on a single executor core, making the number of tasks equal to the number of partitions in each stage
- Understanding stage boundaries and task allocation is critical for optimizing Spark applications, as shuffle operations create network I/O bottlenecks and task skew causes resource underutilization
Understanding Spark’s Execution Model
Spark’s execution model transforms your high-level DataFrame or RDD operations into a directed acyclic graph (DAG) of stages and tasks. When you call an action like collect() or count(), Spark’s DAG scheduler analyzes your transformations and creates an execution plan.
The hierarchy works like this:
- Each action triggers at least one job.
- Jobs split into stages at shuffle boundaries.
- Each stage divides into tasks, one per partition.
- Tasks are the actual units of work distributed to executors.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("StagesExample").getOrCreate()
# Create sample data
data = [(1, "A", 100), (2, "B", 200), (3, "A", 150), (4, "C", 300)]
df = spark.createDataFrame(data, ["id", "category", "value"])
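# One action (collect) with multiple stages. The UI may still list more than one job:
# the sort samples its input to choose range boundaries, and AQE runs shuffle stages as separate jobs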
result = df.groupBy("category") \
.agg({"value": "sum"}) \
.orderBy("category") \
.collect()
This simple operation creates multiple stages because both groupBy and orderBy require shuffles.
Narrow vs Wide Transformations
Transformations determine stage boundaries. Narrow transformations process data within a single partition without moving data across the cluster. Wide transformations require shuffling data between partitions.
Narrow transformations include:
- map(), filter(), flatMap()
- union() (it concatenates the parents' partitions, so no data moves)
- mapPartitions(), mapValues()
Wide transformations include:
- groupBy(), reduceByKey()
- join(), cogroup()
- distinct(), repartition()
- sortBy(), orderBy()
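The difference shows up in which input partitions each output partition reads. The pure-Python model below (not Spark code) tracks those dependencies: a filter builds output partition i from input partition i alone, while a hash shuffle sends each row to partition hash(key) mod N, so one output partition can depend on many inputs. The sample rows are invented, and `zlib.crc32` stands in for Spark's Murmur3-based hash partitioner; both are assumptions for illustration.

```python
import zlib

# Input: 4 partitions of (key, value) rows (invented sample data)
partitions = [
    [("a", 1), ("b", 2), ("c", 3)],
    [("a", 4), ("d", 5)],
    [("b", 6), ("c", 7), ("d", 8)],
    [("a", 9), ("e", 10)],
]

def narrow_filter(parts, predicate):
    """Narrow: output partition i is built only from input partition i."""
    out = [[row for row in part if predicate(row)] for part in parts]
    deps = {i: {i} for i in range(len(parts))}
    return out, deps

def hash_shuffle(parts, num_out):
    """Wide: each row moves to partition hash(key) % num_out.

    crc32 stands in for Spark's Murmur3 hash (an assumption for illustration).
    """
    out = [[] for _ in range(num_out)]
    deps = {j: set() for j in range(num_out)}
    for i, part in enumerate(parts):
        for key, value in part:
            j = zlib.crc32(key.encode("utf-8")) % num_out
            out[j].append((key, value))
            deps[j].add(i)  # output partition j reads from input partition i
    return out, deps

filtered, narrow_deps = narrow_filter(partitions, lambda row: row[1] > 2)
shuffled, wide_deps = hash_shuffle(partitions, 3)
print("narrow deps:", narrow_deps)
print("wide deps:  ", wide_deps)
```

Every row with key "a" comes from three different input partitions, so whichever output partition receives "a" cannot be computed until all of them have produced their data; that wait is exactly where a stage boundary goes.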
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
val spark = SparkSession.builder().appName("Transformations").getOrCreate()
val events = spark.read.parquet("s3://bucket/events/")
// Stage 1: All narrow transformations execute in ONE stage
val filtered = events
.filter(col("event_type") === "click")
.withColumn("timestamp_ms", col("timestamp") * 1000)
.select("user_id", "product_id", "timestamp_ms")
// Stage 2: groupBy forces a shuffle, creating a NEW stage
val aggregated = filtered
.groupBy("user_id")
.agg(
count("product_id").as("click_count"),
max("timestamp_ms").as("last_click")
)
// Stage 3: orderBy creates ANOTHER shuffle and stage
val sorted = aggregated.orderBy(desc("click_count"))
sorted.write.parquet("s3://bucket/results/")
This code creates three stages. The DAG scheduler identifies two shuffle boundaries (groupBy and orderBy) that split the execution.
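The scheduler's cut can be approximated with a toy model that walks a linear chain of operations and starts a new stage at every operation needing a shuffle. This is a simplification under stated assumptions: the real DAG scheduler works on RDD dependencies rather than operation names, the optimizer can add or remove exchanges, and the hypothetical `WIDE` set only covers the operations used here. Applied to the Scala pipeline above, it yields three stages.

```python
from typing import List

# Assumption: operations that require a shuffle in this simplified model
WIDE = {"groupBy", "orderBy", "join", "distinct", "repartition"}

def split_into_stages(ops: List[str]) -> List[List[str]]:
    """Group a linear chain of operations (starting at a source) into stages."""
    stages: List[List[str]] = [[]]
    for op in ops:
        if op in WIDE:
            # A shuffle ends the current stage; the wide op's reduce side
            # runs at the start of the next stage.
            stages.append([])
        stages[-1].append(op)
    return stages

# The Scala pipeline above, as a chain of operation names
pipeline = ["read", "filter", "withColumn", "select", "groupBy", "agg", "orderBy", "write"]
for n, stage in enumerate(split_into_stages(pipeline), start=1):
    print(f"Stage {n}: {stage}")
```

All the narrow steps collapse into the first stage, which is why chaining filter, withColumn, and select costs no extra scheduling round.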
Stage Composition and Task Execution
A stage runs one task per partition it processes: if 200 partitions enter a stage, Spark creates 200 tasks. Each task runs independently on an executor core.
from pyspark.sql.functions import col, count, sum as _sum
# Check current partition count
df = spark.read.parquet("hdfs://data/transactions")
print(f"Partitions: {df.rdd.getNumPartitions()}") # e.g., 128
# Stage 1: 128 tasks (one per partition)
# All narrow transformations execute together
enriched = df \
.filter(col("amount") > 0) \
.withColumn("amount_usd", col("amount") * col("exchange_rate")) \
.select("customer_id", "amount_usd", "category")
# Stage 2: the shuffle redistributes rows into spark.sql.shuffle.partitions partitions
# Default: spark.sql.shuffle.partitions = 200, so 128 input partitions become 200
# This stage will have 200 tasks (fewer if AQE coalesces small partitions)
aggregated = enriched \
.groupBy("customer_id", "category") \
.agg(_sum("amount_usd").alias("total_spent"))
# Control shuffle partitions explicitly. The setting is read when an action plans
# the query, so it also applies to `aggregated` if that runs after this line
spark.conf.set("spark.sql.shuffle.partitions", "50")
# Now Stage 2 will have at most 50 tasks
aggregated_optimized = enriched \
.groupBy("customer_id", "category") \
.agg(_sum("amount_usd").alias("total_spent"))
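The spark.sql.shuffle.partitions parameter (default 200) sets the partition count after DataFrame and SQL shuffles, and therefore the task count of the stage that reads the shuffle output. With Adaptive Query Execution enabled (the default since Spark 3.2), Spark may coalesce small shuffle partitions at runtime, so the actual task count can be lower.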
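A quick pure-Python model shows why the post-shuffle task count follows the setting rather than the data: every key hashes into one of N shuffle partitions, and one task is scheduled per partition even when some receive no rows (the situation AQE's coalescing targets). The `zlib.crc32` hash and the 30 synthetic customer keys are assumptions for illustration, not Spark's actual partitioner or data.

```python
import zlib
from collections import Counter
from typing import Iterable, List

def shuffle_partition(key: str, num_partitions: int) -> int:
    # Assumption: crc32 stands in for Spark's Murmur3 hash; the principle
    # (hash the key, then take a non-negative modulo) is the same.
    return zlib.crc32(key.encode("utf-8")) % num_partitions

def rows_per_shuffle_partition(keys: Iterable[str], num_partitions: int) -> List[int]:
    counts = Counter(shuffle_partition(k, num_partitions) for k in keys)
    # One entry (one task) per shuffle partition, even if it receives no rows
    return [counts.get(p, 0) for p in range(num_partitions)]

# 1,000 rows spread over only 30 distinct grouping keys (synthetic data)
keys = [f"customer_{i % 30}" for i in range(1000)]
for n in (200, 50):
    sizes = rows_per_shuffle_partition(keys, n)
    empty = sum(1 for s in sizes if s == 0)
    print(f"{n} shuffle partitions -> {len(sizes)} tasks, {empty} with no rows")
```

With only 30 keys, at least 170 of the 200 default partitions must be empty, yet each still costs a task launch; a lower setting (or AQE coalescing) removes that overhead.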
Visualizing Stages in Spark UI
The Spark UI (served by the driver on port 4040 by default, e.g. http://localhost:4040 for a local run) shows the execution plan. Each job displays its stages, and clicking a stage reveals task-level metrics.
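# Generate a job with several shuffles to examine
from pyspark.sql.functions import col
# Disable automatic broadcast joins so this small demo still shuffles for the join
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
df1 = spark.range(0, 1000000).withColumn("key", col("id") % 100)
# One row per key keeps the join output at 1,000,000 rows
# (a many-to-many join of two 1M-row tables on 100 keys would produce 10^10 rows)
df2 = spark.range(0, 100).withColumnRenamed("id", "key")
# Expected shuffle boundaries (stage IDs vary with Spark version and AQE):
# - map stages for df1 and df2 that write shuffle output for the join
# - join + aggregation (the join output is already partitioned by key,
#   so groupBy("key") usually needs no extra shuffle)
# - a final stage after the range-partitioning shuffle for orderBy
result = df1.join(df2, "key") \
.groupBy("key") \
.count() \
.orderBy("count") \
.collect()
# Keep the application running to examine UI
input("Check Spark UI at http://localhost:4040, then press Enter...")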
In the UI, you’ll see:
- DAG Visualization: Boxes represent stages, arrows show dependencies
- Task Metrics: Duration, shuffle read/write, GC time per task
- Stage Details: Number of tasks, data size, execution timeline
Optimizing Stage Boundaries
Cutting the number and size of shuffles is often the biggest performance win, because every shuffle writes intermediate data to disk, transfers it across the network, and reads it back.
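import org.apache.spark.sql.functions.{broadcast, col}
// Assumes `df` (one row per event, with user_id) and a small `users` table (user_id, region)
// INEFFICIENT: shuffle join (used when users exceeds the auto-broadcast threshold)
val result1 = df
.groupBy("user_id").count() // Shuffle 1
.filter(col("count") > 5) // Filtering right after the aggregation shrinks the join input
.join(users, "user_id") // Shuffle 2 (users is shuffled to match)
.groupBy("region").count() // Shuffle 3
// OPTIMIZED: broadcast the small users table
val result2 = df
.groupBy("user_id").count()
.filter(col("count") > 5)
.join(broadcast(users), "user_id") // Broadcast hash join: no shuffle
.groupBy("region").count()
// FURTHER OPTIMIZED: if a per-user summary is already maintained upstream
// (hypothetical `userActivity`: one row per user_id with an activity_count column),
// the per-user groupBy disappears
val result3 = userActivity
.filter(col("activity_count") > 5) // Filter before the join
.join(broadcast(users), "user_id")
.groupBy("region").count() // Single shuffle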
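Spark broadcasts a join side automatically when its estimated size is below spark.sql.autoBroadcastJoinThreshold (10MB by default), and the broadcast() hint forces it for larger tables that still fit in memory. Spark sends a full copy of the small DataFrame to each executor, so every partition of the large side joins locally and the large side is never shuffled.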
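In plain Python, a broadcast hash join amounts to building a lookup table from the small side once, giving every partition of the large side its own copy, and joining each partition locally. The sketch below models that idea; it is not Spark's implementation, and the `users` rows and event partitions are made-up sample data. Note that the large side keeps its original partitioning: no row moves between partitions.

```python
from typing import Dict, List, Tuple

# Small dimension table (user_id -> region): the side Spark would broadcast
users: List[Tuple[int, str]] = [(1, "EU"), (2, "US"), (3, "APAC")]

# Large fact table, already split into partitions of (user_id, event) rows
event_partitions: List[List[Tuple[int, str]]] = [
    [(1, "click"), (2, "view")],
    [(3, "click"), (1, "view"), (4, "click")],  # user 4 has no match
]

def broadcast_hash_join(large_parts, small_rows):
    # Build the hash table once (Spark does this for the broadcast side)...
    lookup: Dict[int, str] = dict(small_rows)
    joined_parts = []
    # ...and reuse it in every partition; each partition joins locally,
    # so the large side is never repartitioned (no shuffle).
    for part in large_parts:
        joined_parts.append(
            # Inner join: rows without a match (user 4) are dropped
            [(uid, event, lookup[uid]) for uid, event in part if uid in lookup]
        )
    return joined_parts

result = broadcast_hash_join(event_partitions, users)
print(result)
```

The trade-off is memory: every executor holds the whole small table, which is why the broadcast side has to stay small.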
Task-Level Optimization
Ideally, tasks in a stage take roughly the same time. A few tasks that run far longer than the rest usually indicate data skew: some partitions hold much more data than others.
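from pyspark.sql.functions import col, concat, lit, rand, split, when, sum as _sum
# Create skewed data (about 90% of records have key=1)
skewed_data = spark.range(0, 1000000) \
.withColumn("key", when(col("id") % 100 < 90, lit(1)).otherwise(col("id") % 10))
# The shuffle partition holding key=1 is far larger than the others. For count(),
# map-side partial aggregation softens this; non-combinable aggregations
# (e.g. collect_list) and joins feel the full skew as one very slow task
skewed_result = skewed_data.groupBy("key").count()
# Solution 1: Salt the key to distribute load
salted_data = skewed_data \
.withColumn("salt", (rand() * 10).cast("int")) \
.withColumn("salted_key", concat(col("key").cast("string"), lit("_"), col("salt").cast("string")))
salted_result = salted_data \
.groupBy("salted_key") \
.count() \
.withColumn("original_key", split(col("salted_key"), "_")[0]) \
.groupBy("original_key") \
.agg(_sum("count").alias("total_count"))
# Solution 2 (joins): let Adaptive Query Execution split oversized partitions at runtime
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# Caution: repartition(50, "key") hashes on the skewed key itself, so all key=1 rows
# still land in one partition; repartitioning by key does not remove this skew
repartitioned = skewed_data.repartition(50, "key")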
Monitor task durations in the stage's summary metrics in the Spark UI. If the max task duration is around 10x the median or more, the stage has a skew problem.
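Both the rule of thumb and the salting fix can be checked with a small pure-Python model. `looks_skewed` applies the max/median ratio to a list of task durations (the 10x default mirrors the heuristic above; the durations are invented). The salting part assigns rows to partitions by salted key and then merges the partial counts, showing that the hot key is spread out while the final totals stay the same. Python's seeded `random` stands in for `rand()` and `zlib.crc32` for Spark's hash; both are assumptions.

```python
import random
import statistics
import zlib
from collections import Counter
from typing import Dict, List, Sequence

def looks_skewed(durations: Sequence[float], ratio: float = 10.0) -> bool:
    """Flag a stage whose slowest task takes `ratio` times the median or more."""
    return max(durations) >= ratio * statistics.median(durations)

def partition_of(key: str, num_partitions: int) -> int:
    return zlib.crc32(key.encode("utf-8")) % num_partitions  # stand-in hash

def rows_per_partition(keys: List[str], num_partitions: int) -> List[int]:
    counts = Counter(partition_of(k, num_partitions) for k in keys)
    return [counts.get(p, 0) for p in range(num_partitions)]

# Synthetic skewed data, shaped like the PySpark example: most rows have key "1"
rng = random.Random(42)
keys = ["1" if i % 100 < 90 else str(i % 10) for i in range(10_000)]
N = 16

# Without salting, every "1" row hashes to the same partition
plain = rows_per_partition(keys, N)

# Phase 1: append a random salt 0-9 to each key and count per salted key
salted_keys = [f"{k}_{rng.randrange(10)}" for k in keys]
salted = rows_per_partition(salted_keys, N)
partial: Dict[str, int] = Counter(salted_keys)

# Phase 2: strip the salt and merge the partial counts per original key
totals: Dict[str, int] = Counter()
for salted_key, n in partial.items():
    totals[salted_key.split("_")[0]] += n

print("largest partition without salt:", max(plain))
print("largest partition with salt:   ", max(salted))
print("totals match:", totals == Counter(keys))
print("skewed durations flagged:", looks_skewed([2, 2, 3, 2, 40]))
```

The second aggregation is cheap because it only sees one partial count per salted key, so the extra shuffle costs far less than the straggler it removes.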
Controlling Parallelism
The number of tasks executing simultaneously is bounded by the available executor cores: executors times cores per executor, divided by spark.task.cpus (1 by default).
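# Configuration affects task parallelism. Executor settings are fixed when the
# application starts: if a session already exists, getOrCreate() returns it and
# only runtime SQL settings such as spark.sql.shuffle.partitions take effect
spark = SparkSession.builder \
.appName("Parallelism") \
.config("spark.executor.instances", "10") \
.config("spark.executor.cores", "4") \
.config("spark.sql.shuffle.partitions", "400") \
.getOrCreate()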
# Total parallelism: 10 executors * 4 cores = 40 tasks running concurrently
# If a stage has 400 tasks, they execute in 10 waves (400/40)
# For small datasets, reduce shuffle partitions
spark.conf.set("spark.sql.shuffle.partitions", "40")
# Now tasks execute in 1 wave, reducing overhead
As a rule of thumb, choose spark.sql.shuffle.partitions so that each task processes roughly 100-200MB of shuffle data and the partition count is at least 2-3x the total number of cores, which keeps every core busy even when some tasks finish early.
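The wave arithmetic and the sizing rule of thumb can be written down directly. The function names are hypothetical, and the 128MB target and 2x multiplier are assumptions chosen from within the ranges given above; treat the result as a starting point to validate in the Spark UI, not a rule.

```python
import math

def concurrent_slots(executors: int, cores_per_executor: int, task_cpus: int = 1) -> int:
    """Tasks that can run at once (spark.task.cpus defaults to 1)."""
    return executors * cores_per_executor // task_cpus

def waves(num_tasks: int, slots: int) -> int:
    """Rounds needed to run every task when `slots` tasks run at a time."""
    return math.ceil(num_tasks / slots)

def suggest_shuffle_partitions(shuffle_bytes: int, slots: int,
                               target_bytes: int = 128 * 1024**2,
                               core_multiplier: int = 2) -> int:
    """Enough partitions to keep each task near target_bytes,
    but never fewer than core_multiplier x the available slots."""
    by_size = math.ceil(shuffle_bytes / target_bytes)
    return max(by_size, core_multiplier * slots)

slots = concurrent_slots(10, 4)           # 40, matching the configuration above
print(waves(400, slots))                  # 10 waves
print(waves(40, slots))                   # 1 wave
print(suggest_shuffle_partitions(100 * 1024**3, slots))  # 100 GiB shuffle -> 800
print(suggest_shuffle_partitions(1 * 1024**3, slots))    # 1 GiB shuffle -> 80
```

For large shuffles the size target dominates; for small ones the core-count floor keeps enough tasks to occupy the cluster.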
Debugging Stage Failures
When stages fail, examine the task-level errors in the Spark UI.
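from pyspark.sql.functions import collect_list, count, sum as _sum
# Common failure: Out of memory during shuffle
large_df = spark.read.parquet("hdfs://data/huge_dataset")
try:
    # collect_list keeps every value of a key in memory: dangerous for large or skewed keys
    result = large_df.groupBy("key").agg(collect_list("value"))
    # Transformations are lazy, so the failure only surfaces when an action runs
    result.write.mode("overwrite").parquet("hdfs://data/grouped_output")  # hypothetical output path
except Exception as e:
    print(f"Error: {e}")
# Solution: Increase executor memory or reduce data per task
# spark.executor.memory is fixed at launch (e.g. spark-submit --executor-memory 8g);
# it cannot be raised with spark.conf.set on a running session
spark.conf.set("spark.sql.shuffle.partitions", "1000") # More, smaller tasks
# Or avoid collecting large lists
result_safe = large_df.groupBy("key").agg(count("value"), _sum("value"))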
Task failures often indicate memory pressure, data skew, or network issues. In the Spark UI, the Stages tab lists failed stages with their failure reason, and a stage's task table shows each failed task's error, stack trace, and the executor it ran on.
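One caveat on the "more, smaller tasks" fix: raising spark.sql.shuffle.partitions shrinks the average partition, but all rows of one key still go to a single task, so the largest key sets a floor on the biggest task (and, with collect_list, on the biggest list held in memory). The pure-Python model below illustrates this with synthetic key sizes and `zlib.crc32` as a stand-in for Spark's hash; both are assumptions.

```python
import zlib
from typing import Dict, List

def partition_loads(key_sizes: Dict[str, int], num_partitions: int) -> List[int]:
    """Rows per shuffle partition when each key's rows all go to one partition."""
    loads = [0] * num_partitions
    for key, rows in key_sizes.items():
        loads[zlib.crc32(key.encode("utf-8")) % num_partitions] += rows  # stand-in hash
    return loads

# Synthetic data: one hot key with 5M rows plus 10,000 keys with 100 rows each
key_sizes = {"hot": 5_000_000}
key_sizes.update({f"k{i}": 100 for i in range(10_000)})

for n in (200, 1000):
    loads = partition_loads(key_sizes, n)
    print(f"{n} partitions: average {sum(loads) // n:,} rows, largest {max(loads):,} rows")
# The average drops 5x, but the largest partition never falls below
# the 5,000,000 rows of the hot key
```

When the largest task is dominated by one key, only a change to the key itself (salting, or an aggregation that does not materialize every value) shrinks it.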
Understanding stages and tasks transforms Spark from a black box into a predictable execution engine. Monitor your jobs, identify shuffle boundaries, and optimize partition counts to build efficient data pipelines.