Apache Spark - DAG (Directed Acyclic Graph) Explained
Key Insights
• Spark’s DAG execution model transforms high-level operations into optimized stages of tasks, enabling fault tolerance through lineage tracking and avoiding MapReduce’s practice of writing every intermediate result to disk between chained jobs
• Understanding DAG construction reveals how Spark places stage boundaries at shuffle operations (wide transformations), running tasks in parallel within each stage and pipelining narrow transformations together
• For DataFrame and SQL code, the physical plan Spark executes can differ significantly from the logical plan developers write: the Catalyst optimizer applies optimizations such as predicate pushdown and column pruning, and whole-stage code generation fuses operators together
Understanding Spark’s DAG Architecture
Apache Spark builds a Directed Acyclic Graph of computations as you apply transformations to RDDs, DataFrames, or Datasets, and runs it when you call an action. For RDDs, the nodes of this graph are RDDs and the edges are the dependencies created by the transformations that derive one RDD from another; DataFrames and Datasets are first described by a logical plan that Spark compiles down to such a graph. The “acyclic” property means no circular dependencies exist, so there is always a valid order in which to compute the data, and any lost piece can be rebuilt by following its dependencies back toward the source, which enables efficient recovery from failures.
When an action triggers a job, the driver program already holds this graph and no data has been processed yet. The DAG scheduler in the driver then converts the plan into a physical execution plan by dividing it into stages and tasks.
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("DAGExample").getOrCreate()

# Read data - becomes the source node of the DAG
users = spark.read.parquet("s3://data/users")

# Transformations - build the DAG (lazy evaluation)
active_users = users.filter(users.status == "active")
user_summary = active_users.groupBy("country").count()

# Action - triggers DAG execution
result = user_summary.collect()
```
In this example, no data is processed until collect() is called (apart from any file listing and schema discovery that spark.read performs up front). Spark builds the complete DAG first, optimizes it, then executes it.
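To make lazy evaluation concrete, here is a toy model in plain Python, not Spark's API. The hypothetical `LazyDataset` class only records transformations and applies them all when the `collect()` action runs; a counter shows that building the chain executes nothing. Unlike Spark, it applies each step to the whole list rather than pipelining records through every step per partition.

```python
class LazyDataset:
    """Toy dataset: records transformations and runs them only on collect()."""

    executed_steps = 0  # class-wide counter of steps that actually ran

    def __init__(self, source, steps=()):
        self.source = source       # input records (a plain list here)
        self.steps = tuple(steps)  # recorded transformations, in order

    def filter(self, pred):
        # Transformation: returns a new dataset with one more recorded step
        return LazyDataset(self.source, self.steps + (("filter", pred),))

    def map(self, fn):
        return LazyDataset(self.source, self.steps + (("map", fn),))

    def collect(self):
        # Action: only now are the recorded steps applied to the data
        data = list(self.source)
        for kind, fn in self.steps:
            LazyDataset.executed_steps += 1
            if kind == "filter":
                data = [x for x in data if fn(x)]
            else:
                data = [fn(x) for x in data]
        return data


users = LazyDataset(["active:FR", "inactive:US", "active:US"])
active = users.filter(lambda u: u.startswith("active"))
countries = active.map(lambda u: u.split(":")[1])

print(LazyDataset.executed_steps)  # 0: building the chain ran nothing
print(countries.collect())         # ['FR', 'US']
print(LazyDataset.executed_steps)  # 2: both recorded steps ran
```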
Narrow vs Wide Transformations
Spark classifies transformations into two categories that fundamentally determine how the DAG is divided into stages:
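Narrow transformations compute each output partition from a small, fixed set of parent partitions (usually exactly one), so no data has to move between executors. Operations like map(), filter(), and union() are narrow, and consecutive narrow transformations can be pipelined together in a single stage.
Wide transformations require shuffling data across partitions. Operations like groupBy(), reduceByKey(), and join() trigger shuffle operations that create stage boundaries in the DAG. A join avoids the shuffle when both inputs are already partitioned the same way or when one side is broadcast.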
```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().appName("Transformations").getOrCreate()
import spark.implicits._ // enables the $"column" syntax

val transactions = spark.read.parquet("transactions.parquet")

// Narrow transformations - pipelined into a single stage
val filtered = transactions
  .filter($"amount" > 100)             // Narrow
  .withColumn("tax", $"amount" * 0.1)  // Narrow
  .select("user_id", "amount", "tax")  // Narrow

// Wide transformation - creates a new stage
val userTotals = filtered
  .groupBy("user_id")                  // Wide - shuffle boundary
  .agg(
    sum("amount").as("total_amount"),
    avg("amount").as("avg_amount")
  )

userTotals.write.parquet("output.parquet")
```
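The DAG for this job contains two stages: the first reads the data, applies the narrow transformations (filter, withColumn, select), and computes partial aggregates before the shuffle; the second merges those partial aggregates after the shuffle and writes the output.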
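The difference between the two kinds of transformation can be modeled with plain Python lists standing in for partitions. In this illustrative sketch (`narrow_map`, `hash_shuffle`, and `reduce_by_key` are hypothetical helpers, not Spark functions), a narrow step transforms each partition on its own, while a wide step first moves every record to the partition chosen by hashing its key, roughly what a shuffle with hash partitioning does. It skips the map-side combining Spark performs before a reduceByKey shuffle.

```python
def narrow_map(partitions, fn):
    # Narrow: each output partition is computed from exactly one input partition
    return [[fn(rec) for rec in part] for part in partitions]


def hash_shuffle(partitions, num_out):
    # Wide: any record may move; its destination depends only on hash(key)
    out = [[] for _ in range(num_out)]
    for part in partitions:
        for key, value in part:
            out[hash(key) % num_out].append((key, value))
    return out


def reduce_by_key(partitions, fn, num_out=2):
    # Shuffle so all values for a key meet in one partition, then combine locally
    result = []
    for part in hash_shuffle(partitions, num_out):
        acc = {}
        for key, value in part:
            acc[key] = fn(acc[key], value) if key in acc else value
        result.append(sorted(acc.items()))
    return result


# Two input partitions of (user_id, amount) records
transactions = [[("u1", 150), ("u2", 300)], [("u1", 200), ("u3", 120)]]

with_tax = narrow_map(transactions, lambda kv: (kv[0], kv[1] * 1.1))
print(len(with_tax))  # 2: partitioning is unchanged by a narrow step

totals = reduce_by_key(transactions, lambda a, b: a + b)
print(sorted(kv for part in totals for kv in part))
# [('u1', 350), ('u2', 300), ('u3', 120)]
```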
DAG Stages and Tasks
The DAG scheduler breaks each job into stages at shuffle boundaries. Each stage contains a sequence of narrow transformations that can be computed without data movement, and within each stage Spark creates one task per partition.
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, count

spark = SparkSession.builder \
    .appName("StageExample") \
    .config("spark.sql.shuffle.partitions", "200") \
    .getOrCreate()

# Read and narrow transformations: pipelined into the first stage
orders = spark.read.json("orders.json")
enriched = orders \
    .filter(col("status") == "completed") \
    .withColumn("total_with_tax", col("total") * 1.2)

# groupBy needs a shuffle (Exchange), so the aggregation starts a new stage.
# daily_stats is not used by final_result, so it is absent from its plan.
daily_stats = enriched.groupBy("order_date").agg(
    sum("total_with_tax").alias("revenue"),
    count("order_id").alias("order_count")
)

# A second aggregation, shuffled by product_id
products = spark.read.json("products.json")
product_stats = enriched.groupBy("product_id").agg(
    sum("quantity").alias("total_quantity")
)

# The join may add an Exchange for products (or broadcast it if it is small)
final_result = product_stats.join(products, "product_id")
final_result.explain()  # Prints the physical plan; Exchange nodes mark shuffles
```
Use explain() to print the physical plan. Exchange operators in the output correspond to shuffles, and each one marks a stage boundary.
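How the scheduler cuts a chain of operations into stages can be sketched as a walk that opens a new stage at every wide operation. This is a deliberately simplified model under stated assumptions: the plan is a linear list of `(name, is_wide)` pairs taken from the orders example (the join's second input is left out), the task counts are hypothetical, and the whole aggregation is placed in the post-shuffle stage even though Spark computes partial aggregates before the shuffle. Spark's real DAG scheduler works on a graph of RDD dependencies and handles branches such as joins.

```python
def split_into_stages(ops):
    """Group a linear chain of (name, is_wide) operations into stages.

    A wide operation needs a shuffle, so it opens a new stage; narrow
    operations are pipelined into whichever stage is currently open.
    """
    stages = [[]]
    for name, is_wide in ops:
        if is_wide and stages[-1]:
            stages.append([])  # shuffle boundary
        stages[-1].append(name)
    return stages


# Linear part of the orders example (the join's second input is left out)
plan = [
    ("read orders", False),
    ("filter status == completed", False),
    ("withColumn total_with_tax", False),
    ("groupBy product_id + agg", True),
]

stages = split_into_stages(plan)

# Hypothetical task counts: one task per input split in the first stage,
# spark.sql.shuffle.partitions (200) in the stage after the shuffle
tasks_per_stage = [8, 200]
for i, (ops, tasks) in enumerate(zip(stages, tasks_per_stage), start=1):
    print(f"Stage {i}: {tasks} tasks, ops = {ops}")
```

With adaptive query execution enabled, Spark may coalesce the 200 shuffle partitions at runtime, so the actual task count of the second stage can be lower.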
Lineage and Fault Tolerance
The DAG records the complete lineage of transformations. If a partition is lost due to executor failure, Spark recomputes only that partition: it walks back through the lineage until it reaches data that is still available (cached blocks, surviving shuffle output, or the source data) and replays the transformations from there.
```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

val data = spark.sparkContext.parallelize(1 to 1000000, 100)
// Use Long values: per-key sums exceed Int.MaxValue
val step1 = data.map(x => (x % 10, x.toLong))
val step2 = step1.reduceByKey(_ + _)  // wide: shuffle boundary
val step3 = step2.map { case (k, v) => (k, v * 2) }

// Check lineage using toDebugString
println(step3.toDebugString)
step3.collect()
```
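The toDebugString output shows the complete lineage chain, with a new indentation level at each shuffle boundary. If an executor fails while step3 is being computed, Spark can rebuild the lost partitions from the shuffle files written by the stage that feeds reduceByKey (step2); if those shuffle files are lost as well, it reruns that earlier stage from the original data RDD.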
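The recovery path can be sketched in plain Python. In this hypothetical model, each `Node` keeps a link to its parent and the per-partition function that derived it, plus a store of computed partitions. Dropping a stored partition and asking for it again recomputes just that partition by walking back through the lineage. The model has no shuffles, so it only illustrates recovery through narrow dependencies.

```python
class Node:
    """Toy RDD: a parent link plus a per-partition function (its lineage)."""

    def __init__(self, parent=None, fn=None, source=None):
        self.parent = parent
        self.fn = fn
        self.source = source   # only set on the root node
        self.cache = {}        # partition index -> computed records
        self.computations = 0  # partitions this node has computed

    def partition(self, i):
        if i in self.cache:
            return self.cache[i]
        if self.parent is None:
            records = self.source[i]  # root: read the source data
        else:
            # Walk back through the lineage, then apply this node's function
            records = self.fn(self.parent.partition(i))
        self.computations += 1
        self.cache[i] = records
        return records


data = Node(source=[[1, 2, 3], [4, 5, 6]])
step1 = Node(data, lambda part: [x * 10 for x in part])
step2 = Node(step1, lambda part: [x + 1 for x in part])

before = [step2.partition(i) for i in range(2)]

# Simulate losing partition 1 of both derived datasets (an executor died)
step2.cache.pop(1)
step1.cache.pop(1)

recovered = step2.partition(1)
print(recovered == before[1])  # True
print(step1.computations)      # 3: partitions 0 and 1, then 1 again
```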
DAG Optimization Strategies
For DataFrames, Datasets, and SQL, Spark’s Catalyst optimizer applies multiple optimization passes to the logical plan before generating the physical plan:
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("Optimization").getOrCreate()

sales = spark.read.parquet("sales.parquet")

# Written as separate steps: two filters, then a projection
high_value = sales.filter(col("amount") > 1000)
recent = high_value.filter(col("date") >= "2024-01-01")
selected = recent.select("customer_id", "amount", "date")

# Written as one combined filter followed by the projection
optimized = sales \
    .filter((col("amount") > 1000) & (col("date") >= "2024-01-01")) \
    .select("customer_id", "amount", "date")

# Catalyst merges the adjacent filters and prunes unused columns,
# so both versions should produce the same optimized physical plan
print("Step-by-step plan:")
selected.explain(True)
print("\nCombined plan:")
optimized.explain(True)
```
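Catalyst combines adjacent predicates, pushes filters down toward the data source, and prunes columns that are never used, so chaining filters does not cost extra passes over the data. Writing the combined form is mainly about readability and stating intent clearly.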
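A rule like this predicate combining can be sketched as a rewrite over a toy plan. The snippet below is a simplified model, not Catalyst's implementation (Catalyst has a rule named `CombineFilters` that operates on its own tree types). Here a plan is a list of operator tuples, runs of adjacent filters are merged into one filter whose predicate is their conjunction, and the sample rows are hypothetical.

```python
def combine_filters(plan):
    """Merge runs of adjacent ("filter", pred) steps into one conjunctive filter."""
    optimized = []
    for op in plan:
        if op[0] == "filter" and optimized and optimized[-1][0] == "filter":
            prev, cur = optimized[-1][1], op[1]
            # Default arguments bind prev/cur so each lambda keeps its own pair
            optimized[-1] = ("filter", lambda row, p=prev, c=cur: p(row) and c(row))
        else:
            optimized.append(op)
    return optimized


def run(plan, rows):
    # Tiny interpreter for the toy plan
    for kind, arg in plan:
        if kind == "filter":
            rows = [r for r in rows if arg(r)]
        elif kind == "select":
            rows = [{c: r[c] for c in arg} for r in rows]
    return rows


sales = [
    {"customer_id": 1, "amount": 1500, "date": "2024-02-01", "region": "EU"},
    {"customer_id": 2, "amount": 800, "date": "2024-03-01", "region": "US"},
    {"customer_id": 3, "amount": 2000, "date": "2023-12-31", "region": "EU"},
]

plan = [
    ("filter", lambda r: r["amount"] > 1000),
    ("filter", lambda r: r["date"] >= "2024-01-01"),
    ("select", ["customer_id", "amount", "date"]),
]

optimized = combine_filters(plan)
print([kind for kind, _ in optimized])            # ['filter', 'select']
print(run(optimized, sales) == run(plan, sales))  # True
```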
Monitoring DAG Execution
The Spark UI provides detailed DAG visualization and execution metrics. Access it at http://<driver-node>:4040 while the application is running (if port 4040 is already taken, Spark tries the next one, such as 4041).
```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder()
  .appName("DAGMonitoring")
  .config("spark.ui.enabled", "true")
  .config("spark.ui.port", "4040")
  .getOrCreate()
import spark.implicits._ // enables the $"column" syntax

val largeDataset = spark.read.parquet("large_dataset.parquet")

// Complex DAG with multiple stages
val result = largeDataset
  .filter($"value" > 0)
  .groupBy("category")       // shuffle
  .agg(
    sum("value").as("total"),
    avg("value").as("average")
  )
  .join(
    spark.read.parquet("reference.parquet"),
    "category"
  )
  .orderBy($"total".desc)    // another shuffle (range partitioning)

result.write.parquet("output.parquet")

// Keep the application alive so the live UI stays available.
// Enabling spark.eventLog.enabled and using the History Server avoids this.
Thread.sleep(300000)
```
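The DAG visualization on the job and stage detail pages shows stage dependencies, and the stage pages add task-level metrics and an event timeline. Look for skewed stages (where the maximum task duration in the summary metrics is far above the median) or unusually large shuffle reads; both point to optimization opportunities.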
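The "one task taking much longer" pattern can be turned into a simple check. Below is a hedged sketch: given per-task durations for a stage (values you might read off a stage page; the numbers here are made up for illustration), `skew_report` returns the ratio of the slowest task to the median and flags the stage when that ratio exceeds a threshold. The threshold of 3 is an arbitrary assumption, not a Spark setting.

```python
from statistics import median


def skew_report(task_durations_ms, threshold=3.0):
    """Return (max/median ratio, is_skewed) for one stage's task durations."""
    if not task_durations_ms:
        raise ValueError("no tasks")
    mid = median(task_durations_ms)
    ratio = max(task_durations_ms) / mid if mid > 0 else float("inf")
    return ratio, ratio > threshold


# Hypothetical durations (ms) for two stages
balanced = [1200, 1100, 1300, 1250, 1180]
skewed = [900, 950, 1000, 980, 12000]

for name, durations in [("balanced", balanced), ("skewed", skewed)]:
    ratio, flagged = skew_report(durations)
    print(f"{name}: max/median = {ratio:.1f}, skewed = {flagged}")
```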
Practical DAG Optimization Techniques
Reduce shuffle operations by restructuring transformations:
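```python
import pyspark.sql.functions as F

# Assumes df is an existing DataFrame with columns key and value

# Less efficient: two aggregations over the same data, each with its own
# shuffle, joined back together
df.groupBy("key").count().join(
    df.groupBy("key").agg(F.sum("value").alias("sum_value")), "key"
)

# More efficient: one groupBy computes both aggregates with a single shuffle
df.groupBy("key").agg(
    F.count(F.lit(1)).alias("count"),  # counts rows, like count() above
    F.sum("value").alias("sum_value")
)
```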
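When restructuring, check that the new form still computes the same thing. In particular, joining two inputs before aggregating multiplies rows for every key that appears more than once on both sides, so counts and sums taken after the join can be inflated. This plain-Python check illustrates the effect on tiny hypothetical inputs named after the two sides of a join.

```python
from collections import Counter, defaultdict

# Tiny hypothetical inputs: (key, id) rows and (key, value) rows
df1 = [("a", 1), ("a", 2), ("b", 3)]
df2 = [("a", 10), ("a", 20), ("b", 5)]

# Aggregate each side first, then join the results on key
counts = Counter(k for k, _ in df1)
sums = defaultdict(int)
for k, v in df2:
    sums[k] += v
agg_then_join = {k: (counts[k], sums[k]) for k in counts if k in sums}

# Join first (inner join on key), then aggregate
joined = [(k1, i, v) for k1, i in df1 for k2, v in df2 if k1 == k2]
join_then_agg = {}
for k, _, v in joined:
    c, s = join_then_agg.get(k, (0, 0))
    join_then_agg[k] = (c + 1, s + v)

print(agg_then_join)  # {'a': (2, 30), 'b': (1, 5)}
print(join_then_agg)  # {'a': (4, 60), 'b': (1, 5)}
```

Key "a" appears twice on each side, so the join produces four rows for it and both its count and its sum double.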
Persist intermediate results when reusing an RDD or DataFrame across multiple actions:
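```python
from pyspark.sql.functions import col

# raw_data, reference_data, and complex_udf are assumed to be defined earlier
expensive_computation = raw_data \
    .filter(col("valid") == True) \
    .join(reference_data, "id") \
    .withColumn("computed_field", complex_udf(col("input")))

# Mark for caching; the first action materializes the cached data
expensive_computation.cache()

# Both actions reuse the cached data instead of recomputing the lineage
result1 = expensive_computation.filter(col("type") == "A").count()
result2 = expensive_computation.filter(col("type") == "B").count()

# Release the cached blocks once they are no longer needed
expensive_computation.unpersist()
```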
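The effect of caching on recomputation can be modeled with a counter. In this plain-Python sketch, the hypothetical `Pipeline` class stands in for a DataFrame whose lineage is one expensive function: without caching, every action re-runs that function; after `cache()`, the first action stores the result and later actions reuse it, mirroring how the first action materializes a cached DataFrame.

```python
class Pipeline:
    """Toy dataset whose lineage is one expensive function, with optional caching."""

    def __init__(self, compute):
        self.compute = compute
        self.cached = None
        self.use_cache = False
        self.runs = 0  # how many times the expensive lineage was evaluated

    def cache(self):
        self.use_cache = True  # lazy, as in Spark: nothing is computed yet
        return self

    def rows(self):
        if self.use_cache and self.cached is not None:
            return self.cached
        self.runs += 1
        result = self.compute()
        if self.use_cache:
            self.cached = result
        return result

    def count_where(self, pred):
        # An "action": evaluates the lineage (or reads the cached copy)
        return sum(1 for r in self.rows() if pred(r))


def expensive():
    return [{"type": t} for t in "AABAB"]


uncached = Pipeline(expensive)
uncached.count_where(lambda r: r["type"] == "A")
uncached.count_where(lambda r: r["type"] == "B")

cached = Pipeline(expensive).cache()
result1 = cached.count_where(lambda r: r["type"] == "A")
result2 = cached.count_where(lambda r: r["type"] == "B")

print(uncached.runs, cached.runs)  # 2 1
print(result1, result2)            # 3 2
```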
Broadcast small datasets to avoid shuffle joins:
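```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder().getOrCreate()
val largeFacts = spark.read.parquet("facts.parquet")
val smallDimension = spark.read.parquet("dimension.parquet")

// Broadcast join: ships smallDimension to every executor,
// so largeFacts is joined in place without being shuffled.
// Spark also does this automatically for tables below
// spark.sql.autoBroadcastJoinThreshold (10 MB by default).
val result = largeFacts.join(
  broadcast(smallDimension),
  "dimension_id"
)
```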
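Why a broadcast join avoids the shuffle can be seen in a plain-Python model of a broadcast hash join. The small table is turned into a lookup dictionary once and shared with every partition of the large table, and each partition is then joined locally without moving any of its rows. The tables and the `broadcast_hash_join` helper are hypothetical.

```python
def broadcast_hash_join(large_partitions, small_rows, key):
    """Inner-join each partition of a large table against a broadcast lookup."""
    # Built once (on the "driver"), then shared with every partition
    lookup = {}
    for row in small_rows:
        lookup.setdefault(row[key], []).append(row)

    joined_partitions = []
    for part in large_partitions:  # each partition is joined independently
        out = []
        for row in part:
            for match in lookup.get(row[key], []):
                out.append({**row, **match})
        joined_partitions.append(out)
    return joined_partitions


facts = [
    [{"dimension_id": 1, "amount": 10}, {"dimension_id": 2, "amount": 20}],
    [{"dimension_id": 1, "amount": 30}, {"dimension_id": 9, "amount": 40}],
]
dimension = [{"dimension_id": 1, "name": "books"}, {"dimension_id": 2, "name": "music"}]

result = broadcast_hash_join(facts, dimension, "dimension_id")
print(len(result))  # 2: the large side keeps its original partitioning
print(result[1])    # [{'dimension_id': 1, 'amount': 30, 'name': 'books'}]
```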
Understanding DAG construction and execution enables you to write efficient Spark applications that minimize shuffles, maximize parallelism, and leverage Spark’s optimization capabilities. Monitor the Spark UI’s DAG visualization to verify your optimizations translate into the expected physical execution plan.