Spark Streaming - Stateful Processing (mapGroupsWithState)


Key Insights

  • mapGroupsWithState and flatMapGroupsWithState provide arbitrary stateful processing in Structured Streaming, enabling complex business logic like session tracking, custom aggregations, and event correlation with event-time support that the legacy updateStateByKey lacks
  • State management requires explicit timeout configuration (ProcessingTimeTimeout or EventTimeTimeout) to prevent unbounded state growth and ensure timely state eviction
  • The API demands careful handling of state serialization, iterator management, and output modes: mapGroupsWithState supports only Update mode, while flatMapGroupsWithState can also use Append

Understanding Stateful Processing Requirements

Structured Streaming’s built-in aggregations handle simple cases, but real-world scenarios often require custom state management. Consider session tracking where you need to group events by user, maintain session metadata, and expire sessions after inactivity. Or fraud detection systems that correlate transactions across time windows with complex business rules.

mapGroupsWithState (and its multi-output sibling flatMapGroupsWithState) gives you direct control over the state lifecycle. Unlike updateStateByKey from the legacy DStream API, it integrates with Structured Streaming’s event-time processing and watermarking, providing exactly-once semantics and better performance.

Basic Implementation Pattern

Start by defining your state and output case classes. State must be serializable and should contain only the data you need, to minimize memory overhead.

import org.apache.spark.sql.streaming._
import org.apache.spark.sql.{Dataset, SparkSession}
import java.sql.Timestamp

case class InputEvent(userId: String, eventType: String, timestamp: Timestamp, value: Double)
case class SessionState(
  userId: String,
  startTime: Timestamp,
  lastEventTime: Timestamp,
  eventCount: Int,
  totalValue: Double
)
case class SessionOutput(
  userId: String,
  sessionDuration: Long,
  eventCount: Int,
  totalValue: Double,
  isTimeout: Boolean
)

The stateful function takes three parameters: the key, an iterator of the new values for that key, and the state handle. Because this function emits zero or more records per call (nothing on a normal update, one record on timeout), it returns an iterator and must be applied with flatMapGroupsWithState rather than mapGroupsWithState, which expects exactly one output value per group.

def updateSessionState(
  userId: String,
  inputs: Iterator[InputEvent],
  state: GroupState[SessionState]
): Iterator[SessionOutput] = {
  
  // Check if state timed out
  if (state.hasTimedOut) {
    val currentState = state.get
    state.remove()
    return Iterator(SessionOutput(
      userId = currentState.userId,
      sessionDuration = currentState.lastEventTime.getTime - currentState.startTime.getTime,
      eventCount = currentState.eventCount,
      totalValue = currentState.totalValue,
      isTimeout = true
    ))
  }
  
  // Process incoming events
  val events = inputs.toSeq
  val currentState = if (state.exists) state.get else {
    SessionState(userId, events.head.timestamp, events.head.timestamp, 0, 0.0)
  }
  
  val updatedState = events.foldLeft(currentState) { (s, event) =>
    s.copy(
      lastEventTime = if (event.timestamp.after(s.lastEventTime)) event.timestamp else s.lastEventTime,
      eventCount = s.eventCount + 1,
      totalValue = s.totalValue + event.value
    )
  }
  
  // Update state and set timeout
  state.update(updatedState)
  state.setTimeoutDuration("30 minutes")
  
  Iterator.empty // Only output on timeout
}
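The fold above is pure Scala and can be checked outside Spark. Here it is extracted as a standalone function with illustrative sample data:

```scala
import java.sql.Timestamp

case class InputEvent(userId: String, eventType: String, timestamp: Timestamp, value: Double)
case class SessionState(
  userId: String,
  startTime: Timestamp,
  lastEventTime: Timestamp,
  eventCount: Int,
  totalValue: Double
)

// The foldLeft from updateSessionState, extracted so it can be exercised
// without a streaming context.
def accumulate(initial: SessionState, events: Seq[InputEvent]): SessionState =
  events.foldLeft(initial) { (s, event) =>
    s.copy(
      lastEventTime = if (event.timestamp.after(s.lastEventTime)) event.timestamp else s.lastEventTime,
      eventCount = s.eventCount + 1,
      totalValue = s.totalValue + event.value
    )
  }

val base = 1700000000000L
val start = new Timestamp(base)
val events = Seq(
  InputEvent("u1", "click", new Timestamp(base + 1000), 1.5),
  InputEvent("u1", "view",  new Timestamp(base + 5000), 2.5)
)
// Two events accumulate: eventCount = 2, totalValue = 4.0,
// lastEventTime advances to the newest timestamp
val result = accumulate(SessionState("u1", start, start, 0, 0.0), events)
```

Keeping the accumulation logic separable like this makes the streaming function itself a thin wrapper around testable code.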

Applying Stateful Transformations

Wire up the stateful function with your streaming Dataset. The key step is grouping by userId before applying flatMapGroupsWithState.

val spark = SparkSession.builder()
  .appName("StatefulProcessing")
  .master("local[*]")
  .getOrCreate()

import spark.implicits._

val inputStream = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "localhost:9092")
  .option("subscribe", "user-events")
  .load()
  .selectExpr("CAST(value AS STRING)")
  .as[String]
  .map(parseEvent) // Parse JSON to InputEvent

val sessionStream = inputStream
  .groupByKey(_.userId)
  .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.ProcessingTimeTimeout)(updateSessionState)

val query = sessionStream.writeStream
  .outputMode("update")
  .format("console")
  .start()

query.awaitTermination()
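parseEvent is left undefined above. A minimal hand-rolled sketch follows for completeness; the regex extraction here is an illustrative assumption, and a production job would more likely use from_json with an explicit schema or a JSON library:

```scala
import java.sql.Timestamp

case class InputEvent(userId: String, eventType: String, timestamp: Timestamp, value: Double)

// Hypothetical parser sketch. Assumes flat JSON such as:
//   {"userId":"u1","eventType":"click","timestamp":1700000000000,"value":2.5}
// with epoch-millisecond timestamps; malformed input throws.
def parseEvent(json: String): InputEvent = {
  def field(name: String): String = {
    val pattern = ("\"" + name + "\"\\s*:\\s*\"?([^\",}]+)\"?").r
    pattern.findFirstMatchIn(json).map(_.group(1))
      .getOrElse(throw new IllegalArgumentException(s"missing field: $name"))
  }
  InputEvent(
    userId    = field("userId"),
    eventType = field("eventType"),
    timestamp = new Timestamp(field("timestamp").toLong),
    value     = field("value").toDouble
  )
}

val event = parseEvent("""{"userId":"u1","eventType":"click","timestamp":1700000000000,"value":2.5}""")
```

For anything beyond a quick prototype, parsing with from_json keeps schema errors visible as nulls instead of runtime exceptions.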

Event-Time Based Timeouts

Processing-time timeouts work for simple cases, but event-time timeouts provide correctness guarantees when dealing with late data. This requires watermarking configuration.

def updateSessionWithEventTime(
  userId: String,
  inputs: Iterator[InputEvent],
  state: GroupState[SessionState]
): Iterator[SessionOutput] = {
  
  if (state.hasTimedOut) {
    val currentState = state.get
    state.remove()
    return Iterator(SessionOutput(
      userId = currentState.userId,
      sessionDuration = currentState.lastEventTime.getTime - currentState.startTime.getTime,
      eventCount = currentState.eventCount,
      totalValue = currentState.totalValue,
      isTimeout = true
    ))
  }
  
  val events = inputs.toSeq.sortBy(_.timestamp.getTime)
  val currentState = if (state.exists) state.get else {
    SessionState(userId, events.head.timestamp, events.head.timestamp, 0, 0.0)
  }
  
  val updatedState = events.foldLeft(currentState) { (s, event) =>
    s.copy(
      lastEventTime = event.timestamp,
      eventCount = s.eventCount + 1,
      totalValue = s.totalValue + event.value
    )
  }
  
  state.update(updatedState)
  
  // Timeout fires once the watermark passes this event-time timestamp,
  // so late events inside the watermark delay can still extend the session
  val timeoutTimestamp = updatedState.lastEventTime.getTime + (30 * 60 * 1000) // 30 minutes
  state.setTimeoutTimestamp(timeoutTimestamp)
  
  Iterator.empty
}

val sessionStreamEventTime = inputStream
  .withWatermark("timestamp", "10 minutes")
  .groupByKey(_.userId)
  .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.EventTimeTimeout)(updateSessionWithEventTime)

Complex State Management Pattern

Real applications often need multiple state transitions and conditional outputs. Here’s a fraud detection example that tracks transaction patterns; wire it up the same way as the session example, using flatMapGroupsWithState with a processing-time timeout.

case class Transaction(userId: String, amount: Double, location: String, timestamp: Timestamp)
case class FraudState(
  userId: String,
  recentTransactions: Seq[Transaction],
  riskScore: Double,
  flaggedAt: Option[Timestamp]
)
case class FraudAlert(userId: String, reason: String, riskScore: Double, timestamp: Timestamp)

def detectFraud(
  userId: String,
  inputs: Iterator[Transaction],
  state: GroupState[FraudState]
): Iterator[FraudAlert] = {
  
  // Evict idle state on timeout; without this check the state below is
  // re-created and the timeout re-armed, so it would never expire
  if (state.hasTimedOut) {
    state.remove()
    return Iterator.empty
  }
  
  val transactions = inputs.toSeq
  val currentState = if (state.exists) state.get else {
    FraudState(userId, Seq.empty, 0.0, None)
  }
  
  // Keep only last 100 transactions
  val allTransactions = (currentState.recentTransactions ++ transactions)
    .sortBy(_.timestamp.getTime)(Ordering[Long].reverse)
    .take(100)
  
  // Calculate risk factors; keep only full pairs, since sliding(2) over a
  // single-element sequence yields one window of size 1
  val pairs = allTransactions.sliding(2).filter(_.size == 2).toSeq
  
  val rapidTransactions = pairs.count { case Seq(t1, t2) =>
    (t1.timestamp.getTime - t2.timestamp.getTime) < 60000 // Newest first: within 1 minute
  }
  
  val locationChanges = pairs.count { case Seq(t1, t2) => t1.location != t2.location }
  
  val highValueCount = transactions.count(_.amount > 1000)
  
  val newRiskScore = (rapidTransactions * 10) + (locationChanges * 5) + (highValueCount * 15)
  
  val updatedState = currentState.copy(
    recentTransactions = allTransactions,
    riskScore = newRiskScore,
    flaggedAt = if (newRiskScore > 50 && currentState.flaggedAt.isEmpty) 
      Some(transactions.maxBy(_.timestamp.getTime).timestamp) 
      else currentState.flaggedAt
  )
  
  state.update(updatedState)
  state.setTimeoutDuration("2 hours")
  
  // Generate alerts if risk threshold exceeded
  if (newRiskScore > 50) {
    val reasons = Seq(
      if (rapidTransactions > 3) Some("Rapid transactions") else None,
      if (locationChanges > 5) Some("Multiple locations") else None,
      if (highValueCount > 2) Some("High-value transactions") else None
    ).flatten.mkString(", ")
    
    Iterator(FraudAlert(userId, reasons, newRiskScore, transactions.maxBy(_.timestamp.getTime).timestamp))
  } else {
    Iterator.empty
  }
}
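To make the scoring arithmetic concrete, here it is extracted as a pure function and run on sample data (the weights match the example above; the transactions themselves are illustrative):

```scala
import java.sql.Timestamp

case class Transaction(userId: String, amount: Double, location: String, timestamp: Timestamp)

// Risk scoring from detectFraud as a standalone function: recent is the
// retained history, batch is the newly arrived transactions.
def riskScore(recent: Seq[Transaction], batch: Seq[Transaction]): Double = {
  val all = (recent ++ batch)
    .sortBy(_.timestamp.getTime)(Ordering[Long].reverse)
    .take(100)
  val pairs = all.sliding(2).filter(_.size == 2).toSeq
  val rapid = pairs.count { case Seq(t1, t2) =>
    (t1.timestamp.getTime - t2.timestamp.getTime) < 60000
  }
  val moves = pairs.count { case Seq(t1, t2) => t1.location != t2.location }
  val highValue = batch.count(_.amount > 1000)
  (rapid * 10) + (moves * 5) + (highValue * 15)
}

val base = 1700000000000L
val batch = Seq(
  Transaction("u1",   50.0, "A", new Timestamp(base)),
  Transaction("u1", 2000.0, "A", new Timestamp(base + 30000)),
  Transaction("u1",   60.0, "B", new Timestamp(base + 45000))
)
// Two gaps under a minute (2 * 10), one location change (1 * 5),
// one transaction over 1000 (1 * 15): total risk 40.0
val score = riskScore(Seq.empty, batch)
```

Worked through by hand like this, it is easy to see how quickly the rapid-transaction term dominates, which is useful when tuning the alert threshold.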

State Size Management and Checkpointing

State grows unbounded without proper management. Monitor state metrics and implement cleanup strategies.

val query = sessionStream.writeStream
  .outputMode("update")
  .option("checkpointLocation", "/tmp/checkpoint")
  .foreachBatch { (batchDF: Dataset[SessionOutput], batchId: Long) =>
    // Log per-batch output counts
    println(s"Batch $batchId: ${batchDF.count()} outputs")
    batchDF.write.format("delta").mode("append").save("/data/sessions")
  }
  .start()

// Monitor state store size (lastProgress is null before the first batch)
spark.streams.active.foreach { query =>
  Option(query.lastProgress).foreach { progress =>
    progress.stateOperators.headOption.foreach { op =>
      println(s"State rows: ${op.numRowsTotal}")
    }
  }
}

Serialization impacts performance. Note that Structured Streaming persists state through the Dataset encoders rather than spark.serializer, but Kryo still speeds up shuffling of complex objects elsewhere in the job. It must be configured when the SparkSession is built, since spark.serializer cannot be changed at runtime:

val spark = SparkSession.builder()
  .appName("StatefulProcessing")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.kryo.registrationRequired", "true")
  .config("spark.kryo.classesToRegister",
    "com.example.SessionState,com.example.FraudState")
  .getOrCreate()

mapGroupsWithState and flatMapGroupsWithState enable sophisticated streaming applications that require custom state logic beyond built-in aggregations. Proper timeout configuration, state size management, and serialization tuning are critical for production deployments handling high-volume streams.
