Spark Scala - Broadcast Variables and Accumulators

Key Insights

  • Broadcast variables eliminate redundant data shipping by distributing read-only data once to each executor, making them essential for lookup tables and configuration data in distributed joins
  • Accumulators provide a safe mechanism for aggregating metrics across tasks, but their reliability guarantees differ significantly between actions and transformations—use them in actions only for accurate counts
  • Both shared variable types require explicit lifecycle management; failing to unpersist broadcast variables or misunderstanding accumulator retry semantics leads to memory leaks and incorrect metrics

Introduction to Shared Variables in Spark

When you write a Spark job, closures capture variables from your driver program and serialize them to every task. This works fine for small values, but becomes catastrophic when you’re shipping a 500MB lookup table to thousands of tasks across hundreds of executors. Each task gets its own copy, and your cluster grinds to a halt under serialization overhead and memory pressure.

Spark provides two specialized shared variable types to solve distinct problems: broadcast variables for efficiently distributing read-only data, and accumulators for aggregating write-only results back to the driver. Understanding when and how to use each is fundamental to writing performant Spark applications.

Understanding Broadcast Variables

Broadcast variables solve the “ship once, use everywhere” problem. Instead of serializing data with every task closure, Spark distributes the data once per executor using an efficient peer-to-peer protocol. All tasks on that executor then share a single copy.

The canonical use case is enriching transactional data with dimension tables. Consider a streaming job that needs to map country codes to country names:

import org.apache.spark.sql.SparkSession
import org.apache.spark.broadcast.Broadcast

val spark = SparkSession.builder()
  .appName("BroadcastExample")
  .getOrCreate()

// Dimension data - small enough to fit in memory
val countryLookup: Map[String, String] = Map(
  "US" -> "United States",
  "GB" -> "United Kingdom",
  "DE" -> "Germany",
  "JP" -> "Japan"
  // ... thousands more entries
)

// Without broadcast - this Map ships with EVERY task
val transactionsRdd = spark.sparkContext.parallelize(
  Seq(("txn1", "US", 100.0), ("txn2", "GB", 250.0), ("txn3", "DE", 175.0))
)

// Bad: countryLookup captured in closure, serialized per task
val enrichedBad = transactionsRdd.map { case (id, code, amount) =>
  (id, countryLookup.getOrElse(code, "Unknown"), amount)
}

// With broadcast - ships once per executor
val countryBroadcast: Broadcast[Map[String, String]] = 
  spark.sparkContext.broadcast(countryLookup)

// Good: only the broadcast reference ships with tasks
val enrichedGood = transactionsRdd.map { case (id, code, amount) =>
  (id, countryBroadcast.value.getOrElse(code, "Unknown"), amount)
}

enrichedGood.collect().foreach(println)

The performance difference is dramatic. With a 100MB lookup table, 1000 tasks, and 50 executors, the naive approach ships 100GB of data (100MB × 1000 tasks). Broadcasting ships roughly 5GB (100MB × 50 executors), a 20× reduction, and every additional task scheduled on an executor reuses the local copy for free.
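
The arithmetic is simple enough to check directly; a throwaway pure-Scala sketch of the two shipping models, using the numbers from the scenario above (no cluster needed):

```scala
// Back-of-the-envelope model of the two shipping strategies.
def closureShippedMb(tableMb: Long, numTasks: Long): Long =
  tableMb * numTasks  // the table rides along in every task closure

def broadcastShippedMb(tableMb: Long, numExecutors: Long): Long =
  tableMb * numExecutors  // shipped once per executor, shared by its tasks

val naive = closureShippedMb(100L, 1000L)  // 100000 MB, i.e. ~100 GB
val bcast = broadcastShippedMb(100L, 50L)  // 5000 MB, i.e. ~5 GB
println(s"naive=${naive}MB broadcast=${bcast}MB savings=${naive / bcast}x")
```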

Broadcast variables also shine for ML model parameters, feature dictionaries, and configuration objects that every task needs to reference.

Broadcast Variable Best Practices

Broadcast variables require careful lifecycle management. They consume executor memory until explicitly released, and improper handling leads to memory leaks in long-running applications.

import org.apache.spark.broadcast.Broadcast

class BroadcastManager(spark: SparkSession) {
  private var currentBroadcast: Option[Broadcast[Map[String, String]]] = None
  
  def updateLookupData(newData: Map[String, String]): Broadcast[Map[String, String]] = {
    // Clean up previous broadcast before creating new one
    currentBroadcast.foreach { bc =>
      bc.unpersist(blocking = true)  // Remove from executor memory
      bc.destroy()                    // Release driver resources
    }
    
    // Create fresh broadcast with updated data
    val newBroadcast = spark.sparkContext.broadcast(newData)
    currentBroadcast = Some(newBroadcast)
    newBroadcast
  }
  
  def cleanup(): Unit = {
    currentBroadcast.foreach { bc =>
      bc.unpersist(blocking = true)
      bc.destroy()
    }
    currentBroadcast = None
  }
}

// Usage in streaming application
val manager = new BroadcastManager(spark)

// Initial load
var lookupBc = manager.updateLookupData(loadDimensionTable())

// Periodic refresh (e.g., every hour)
// lookupBc = manager.updateLookupData(loadDimensionTable())

// Application shutdown
manager.cleanup()

Key rules for broadcast variables:

  1. Keep them immutable. Broadcasting a mutable collection invites race conditions and undefined behavior.
  2. Size appropriately. Broadcast variables should fit comfortably in executor memory. The practical limit is typically 2-8GB depending on your cluster configuration.
  3. Use Kryo serialization. Register broadcast types with Kryo for faster serialization and smaller payloads.
  4. Call unpersist explicitly. Don’t rely on garbage collection for timely cleanup.
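
Rule 3 in practice looks like the configuration sketch below. The registered class name is hypothetical (substitute the concrete types you actually broadcast); setting registrationRequired is optional but fails fast on anything you forgot to register:

```scala
import org.apache.spark.sql.SparkSession

// Sketch: enable Kryo and register the types you broadcast.
// "com.example.CountryLookup" is a hypothetical application class.
val spark = SparkSession.builder()
  .appName("KryoBroadcast")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.kryo.registrationRequired", "true")  // fail fast on unregistered types
  .config("spark.kryo.classesToRegister", "com.example.CountryLookup")
  .getOrCreate()
```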

Introduction to Accumulators

While broadcast variables push data from driver to executors, accumulators aggregate data from executors back to the driver. They’re write-only from the executor perspective—tasks can only add to them, not read the current value.

Spark provides built-in numeric accumulators for common counting scenarios:

import org.apache.spark.sql.SparkSession
import org.apache.spark.util.LongAccumulator

val spark = SparkSession.builder()
  .appName("AccumulatorExample")
  .getOrCreate()

// Create accumulators for ETL quality metrics
val validRecords: LongAccumulator = spark.sparkContext.longAccumulator("validRecords")
val malformedRecords: LongAccumulator = spark.sparkContext.longAccumulator("malformedRecords")
val nullFieldCount: LongAccumulator = spark.sparkContext.longAccumulator("nullFields")

case class RawRecord(id: String, value: String, timestamp: String)
case class ParsedRecord(id: String, value: Double, timestamp: Long)

def parseRecord(raw: RawRecord): Option[ParsedRecord] = {
  try {
    if (raw.value == null || raw.value.isEmpty) {
      nullFieldCount.add(1)
      None
    } else {
      val parsed = ParsedRecord(
        raw.id,
        raw.value.toDouble,
        raw.timestamp.toLong
      )
      validRecords.add(1)
      Some(parsed)
    }
  } catch {
    case _: NumberFormatException =>
      malformedRecords.add(1)
      None
  }
}

val rawData = spark.sparkContext.parallelize(Seq(
  RawRecord("1", "100.5", "1699900000"),
  RawRecord("2", "invalid", "1699900001"),
  RawRecord("3", null, "1699900002"),
  RawRecord("4", "200.0", "1699900003")
))

val parsed = rawData.flatMap(parseRecord)  // flatMap is a transformation: these
                                           // counts are exact only if no task retries occur
val results = parsed.collect()  // Action triggers computation

println(s"Valid records: ${validRecords.value}")
println(s"Malformed records: ${malformedRecords.value}")
println(s"Null fields: ${nullFieldCount.value}")
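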
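
Besides longAccumulator, SparkContext also exposes doubleAccumulator for summing doubles and collectionAccumulator for gathering small lists of values. A quick sketch, assuming local mode for illustration:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("BuiltinAccumulators")
  .master("local[*]")  // assumption: local mode for illustration
  .getOrCreate()

val totalAmount = spark.sparkContext.doubleAccumulator("totalAmount")
val badIds = spark.sparkContext.collectionAccumulator[String]("badIds")

spark.sparkContext.parallelize(Seq(("a", 1.5), ("", 2.0), ("c", 3.5)))
  .foreach { case (id, amount) =>  // foreach is an action: counts are reliable
    totalAmount.add(amount)
    if (id.isEmpty) badIds.add(id)
  }

println(totalAmount.value)   // 7.0
println(badIds.value.size)   // 1
spark.stop()
```

Note that collectionAccumulator returns a java.util.List on the driver, so keep the collected values small.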

Custom Accumulators

For aggregations beyond simple counting, extend AccumulatorV2. This requires implementing merge logic for combining partial results from different tasks:

import org.apache.spark.util.AccumulatorV2
import scala.collection.mutable

// Custom accumulator for collecting unique error types
class SetAccumulator extends AccumulatorV2[String, Set[String]] {
  private val _set: mutable.Set[String] = mutable.Set.empty
  
  override def isZero: Boolean = _set.isEmpty
  
  override def copy(): AccumulatorV2[String, Set[String]] = {
    val newAcc = new SetAccumulator
    newAcc._set ++= _set
    newAcc
  }
  
  override def reset(): Unit = _set.clear()
  
  override def add(v: String): Unit = _set += v
  
  override def merge(other: AccumulatorV2[String, Set[String]]): Unit = {
    _set ++= other.value
  }
  
  override def value: Set[String] = _set.toSet
}

// Register and use
val errorTypes = new SetAccumulator
spark.sparkContext.register(errorTypes, "errorTypes")

val logs = spark.sparkContext.parallelize(Seq(
  "ERROR: NullPointerException in module A",
  "ERROR: TimeoutException in module B",
  "INFO: Processing complete",
  "ERROR: NullPointerException in module C",
  "ERROR: OutOfMemoryError in module A"
))

logs.foreach { line =>
  if (line.startsWith("ERROR:")) {
    val errorType = line.split(":")(1).trim.split(" ")(0)
    errorTypes.add(errorType)
  }
}

println(s"Unique error types: ${errorTypes.value}")
// Output (set iteration order may vary):
// Unique error types: Set(NullPointerException, TimeoutException, OutOfMemoryError)
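
The copy/add/merge lifecycle is easier to reason about in isolation. This simplified pure-Scala model (no Spark required) mimics what Spark does: each task works on its own copy of the accumulator, and the driver merges the partial results afterwards:

```scala
// Simplified model of AccumulatorV2 semantics, without Spark.
final class LocalSetAcc {
  private val s = scala.collection.mutable.Set.empty[String]
  def add(v: String): Unit = s += v
  def merge(other: LocalSetAcc): Unit = s ++= other.s
  def value: Set[String] = s.toSet
}

val taskPartitions = Seq(
  Seq("NullPointerException", "TimeoutException"),
  Seq("NullPointerException", "OutOfMemoryError")
)

// Each simulated "task" gets its own accumulator copy...
val perTask = taskPartitions.map { part =>
  val acc = new LocalSetAcc
  part.foreach(acc.add)
  acc
}

// ...and the "driver" merges them into the final value.
val driverAcc = new LocalSetAcc
perTask.foreach(driverAcc.merge)
println(driverAcc.value.size)  // 3 unique error types
```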

Accumulator Gotchas and Reliability

Here’s the critical caveat that trips up many developers: accumulators are only guaranteed accurate inside actions, not transformations.

When a task fails and retries, or when Spark re-executes a stage due to executor loss, accumulator updates from the failed attempt aren’t rolled back. In transformations (map, filter, etc.), this leads to double-counting:

val doubleCountRisk: LongAccumulator = spark.sparkContext.longAccumulator("riskyCount")

val data = spark.sparkContext.parallelize(1 to 1000, 10)

// DANGEROUS: accumulator in transformation
val transformed = data.map { x =>
  doubleCountRisk.add(1)  // May count same element multiple times on retry
  x * 2
}

// If a task fails and retries, doubleCountRisk may exceed 1000
transformed.count()

// SAFE: accumulator in action
val safeCount: LongAccumulator = spark.sparkContext.longAccumulator("safeCount")

data.foreach { x =>
  safeCount.add(1)  // foreach is an action - guaranteed accurate
}

println(s"Safe count: ${safeCount.value}")  // Always 1000

The rule is simple: update accumulators inside actions such as foreach or foreachPartition, or after collecting results to the driver. Treat accumulator values produced inside transformations as approximate metrics only.
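
One refinement on that rule: with foreachPartition you can apply a single batched accumulator update per partition instead of one per record, which reduces per-record overhead. A sketch, assuming local mode:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("ForeachPartitionAccumulator")
  .master("local[*]")  // assumption: local mode for illustration
  .getOrCreate()

val processed = spark.sparkContext.longAccumulator("processed")
val data = spark.sparkContext.parallelize(1 to 1000, 10)

// foreachPartition is an action, so the counts are reliable, and we
// add one batched update per partition rather than one per element.
data.foreachPartition { iter =>
  var count = 0L
  iter.foreach(_ => count += 1)
  processed.add(count)
}

println(processed.value)  // 1000
spark.stop()
```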

Practical Patterns and Performance Considerations

Let’s combine both concepts in a realistic ETL pipeline that enriches transaction data while tracking quality metrics:

import org.apache.spark.sql.SparkSession
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.util.LongAccumulator

object TransactionETL {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TransactionETL")
      .getOrCreate()
    
    // Broadcast dimension data
    val merchantLookup: Map[String, (String, String)] = Map(
      "M001" -> ("Amazon", "Retail"),
      "M002" -> ("Netflix", "Entertainment"),
      "M003" -> ("Uber", "Transportation")
    )
    val merchantBc: Broadcast[Map[String, (String, String)]] = 
      spark.sparkContext.broadcast(merchantLookup)
    
    // Quality metric accumulators
    val totalProcessed = spark.sparkContext.longAccumulator("totalProcessed")
    val enrichedCount = spark.sparkContext.longAccumulator("enrichedCount")
    val unknownMerchants = spark.sparkContext.longAccumulator("unknownMerchants")
    
    case class RawTransaction(txnId: String, merchantId: String, amount: Double)
    case class EnrichedTransaction(
      txnId: String, 
      merchantId: String, 
      merchantName: String,
      category: String,
      amount: Double
    )
    
    val transactions = spark.sparkContext.parallelize(Seq(
      RawTransaction("T1", "M001", 150.00),
      RawTransaction("T2", "M002", 15.99),
      RawTransaction("T3", "M999", 50.00),  // Unknown merchant
      RawTransaction("T4", "M003", 25.50)
    ))
    
    // Collect to the driver, then enrich locally: accumulator updates here
    // run on the driver itself, so task retries cannot skew the counts
    val results = scala.collection.mutable.ArrayBuffer[EnrichedTransaction]()
    
    transactions.collect().foreach { txn =>
      totalProcessed.add(1)
      
      merchantBc.value.get(txn.merchantId) match {
        case Some((name, category)) =>
          enrichedCount.add(1)
          results += EnrichedTransaction(
            txn.txnId, txn.merchantId, name, category, txn.amount
          )
        case None =>
          unknownMerchants.add(1)
          results += EnrichedTransaction(
            txn.txnId, txn.merchantId, "Unknown", "Uncategorized", txn.amount
          )
      }
    }
    
    // Report metrics
    println(s"ETL Summary:")
    println(s"  Total processed: ${totalProcessed.value}")
    println(s"  Successfully enriched: ${enrichedCount.value}")
    println(s"  Unknown merchants: ${unknownMerchants.value}")
    
    // Cleanup
    merchantBc.unpersist(blocking = true)
    merchantBc.destroy()
    
    spark.stop()
  }
}

When choosing between broadcast joins and regular joins, consider: broadcast when dimension tables are under 1GB and relatively static; use standard joins for larger or frequently changing data. For metrics, prefer dedicated monitoring frameworks (Prometheus, Datadog) for production observability, reserving accumulators for job-specific quality checks.
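
On the DataFrame side, the same trade-off is expressed with the broadcast() join hint from org.apache.spark.sql.functions. A minimal sketch, with local mode and toy data assumed:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder()
  .appName("BroadcastJoinSketch")
  .master("local[*]")  // assumption: local mode for illustration
  .getOrCreate()
import spark.implicits._

val facts = Seq(("T1", "M001", 150.0), ("T2", "M999", 50.0))
  .toDF("txnId", "merchantId", "amount")
val dims = Seq(("M001", "Amazon", "Retail"))
  .toDF("merchantId", "merchantName", "category")

// The hint asks the optimizer to replicate the small side to every
// executor instead of shuffling both sides of the join.
val joined = facts.join(broadcast(dims), Seq("merchantId"), "left")
val rows = joined.collect()  // 2 rows; M999 gets nulls for the dim columns
rows.foreach(println)
spark.stop()
```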

Broadcast variables and accumulators are fundamental tools in the Spark developer’s toolkit. Master their semantics and limitations, and you’ll write faster, more reliable distributed applications.
