Spark Streaming - Stateful Processing (mapGroupsWithState)
Key Insights
- mapGroupsWithState and its sibling flatMapGroupsWithState provide arbitrary stateful processing in Structured Streaming, enabling complex business logic like session tracking, real-time aggregations, and event correlation that built-in aggregations cannot express
- State management requires explicit timeout configuration (ProcessingTimeTimeout or EventTimeTimeout) to prevent unbounded state growth and ensure timely state eviction
- The API demands careful handling of state serialization, iterator management, and output modes: mapGroupsWithState supports only Update mode, while flatMapGroupsWithState takes an explicit OutputMode (Update or Append)
Understanding Stateful Processing Requirements
Structured Streaming’s built-in aggregations handle simple cases, but real-world scenarios often require custom state management. Consider session tracking where you need to group events by user, maintain session metadata, and expire sessions after inactivity. Or fraud detection systems that correlate transactions across time windows with complex business rules.
mapGroupsWithState gives you direct control over state lifecycle. Unlike the deprecated updateStateByKey from DStreams, it integrates with Structured Streaming’s event-time processing and watermarking, providing exactly-once semantics and better performance.
Basic Implementation Pattern
Start by defining your state and output case classes. State must be serializable and should contain only necessary data to minimize memory overhead.
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.{Dataset, SparkSession}
import java.sql.Timestamp
case class InputEvent(userId: String, eventType: String, timestamp: Timestamp, value: Double)
case class SessionState(
userId: String,
startTime: Timestamp,
lastEventTime: Timestamp,
eventCount: Int,
totalValue: Double
)
case class SessionOutput(
userId: String,
sessionDuration: Long,
eventCount: Int,
totalValue: Double,
isTimeout: Boolean
)
The stateful function takes three parameters: the grouping key, an iterator of that key's values in the current trigger, and the state handle. Note that mapGroupsWithState expects exactly one output record per group; a function that returns an iterator of output records (zero, one, or many per group), like the one below, belongs with flatMapGroupsWithState.
def updateSessionState(
userId: String,
inputs: Iterator[InputEvent],
state: GroupState[SessionState]
): Iterator[SessionOutput] = {
// Check if state timed out
if (state.hasTimedOut) {
val currentState = state.get
state.remove()
return Iterator(SessionOutput(
userId = currentState.userId,
sessionDuration = currentState.lastEventTime.getTime - currentState.startTime.getTime,
eventCount = currentState.eventCount,
totalValue = currentState.totalValue,
isTimeout = true
))
}
// Process incoming events
val events = inputs.toSeq
val currentState = if (state.exists) state.get else {
SessionState(userId, events.head.timestamp, events.head.timestamp, 0, 0.0)
}
val updatedState = events.foldLeft(currentState) { (s, event) =>
s.copy(
lastEventTime = if (event.timestamp.after(s.lastEventTime)) event.timestamp else s.lastEventTime,
eventCount = s.eventCount + 1,
totalValue = s.totalValue + event.value
)
}
// Update state and set timeout
state.update(updatedState)
state.setTimeoutDuration("30 minutes")
Iterator.empty // Only output on timeout
}
Applying Stateful Transformations
Wire up the stateful function with your streaming Dataset. The key is proper grouping with groupByKey before applying the stateful operator.
val spark = SparkSession.builder()
.appName("StatefulProcessing")
.master("local[*]")
.getOrCreate()
import spark.implicits._
val inputStream = spark.readStream
.format("kafka")
.option("kafka.bootstrap.servers", "localhost:9092")
.option("subscribe", "user-events")
.load()
.selectExpr("CAST(value AS STRING)")
.as[String]
.map(parseEvent) // Parse JSON to InputEvent
val sessionStream = inputStream
  .groupByKey(_.userId)
  // flatMapGroupsWithState, not mapGroupsWithState: updateSessionState returns an Iterator
  .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.ProcessingTimeTimeout)(updateSessionState)
val query = sessionStream.writeStream
.outputMode("update")
.format("console")
.start()
query.awaitTermination()
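The pipeline above assumes a parseEvent helper that turns each Kafka value into an InputEvent. A minimal sketch, assuming a hypothetical comma-separated encoding (userId,eventType,epochMillis,value); production code would more likely parse JSON with from_json or a JSON library:

```scala
import java.sql.Timestamp

// InputEvent as defined earlier, repeated here for completeness
case class InputEvent(userId: String, eventType: String, timestamp: Timestamp, value: Double)

// Hypothetical parser for a "userId,eventType,epochMillis,value" line format
def parseEvent(line: String): InputEvent = {
  val Array(userId, eventType, ts, value) = line.split(",", 4)
  InputEvent(userId, eventType, new Timestamp(ts.trim.toLong), value.trim.toDouble)
}
```

Malformed lines will throw here; a real pipeline would wrap this in Try and route failures to a dead-letter sink.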
Event-Time Based Timeouts
Processing-time timeouts work for simple cases, but event-time timeouts provide correctness guarantees when dealing with late data. This requires watermarking configuration.
def updateSessionWithEventTime(
userId: String,
inputs: Iterator[InputEvent],
state: GroupState[SessionState]
): Iterator[SessionOutput] = {
if (state.hasTimedOut) {
val currentState = state.get
state.remove()
return Iterator(SessionOutput(
userId = currentState.userId,
sessionDuration = currentState.lastEventTime.getTime - currentState.startTime.getTime,
eventCount = currentState.eventCount,
totalValue = currentState.totalValue,
isTimeout = true
))
}
val events = inputs.toSeq.sortBy(_.timestamp.getTime)
val currentState = if (state.exists) state.get else {
SessionState(userId, events.head.timestamp, events.head.timestamp, 0, 0.0)
}
val updatedState = events.foldLeft(currentState) { (s, event) =>
s.copy(
lastEventTime = event.timestamp,
eventCount = s.eventCount + 1,
totalValue = s.totalValue + event.value
)
}
state.update(updatedState)
// Set timeout based on last event time plus inactivity period
val timeoutTimestamp = updatedState.lastEventTime.getTime + (30 * 60 * 1000) // 30 minutes
state.setTimeoutTimestamp(timeoutTimestamp)
Iterator.empty
}
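The fold above is plain collection code, so its bookkeeping is easy to verify outside Spark. A small sketch with an illustrative Ev class (not part of the pipeline) showing sorted events folded into last-event time, count, and total, plus the 30-minute timeout timestamp:

```scala
case class Ev(ts: Long, value: Double)

// Events arrive out of order; sort by event time as updateSessionWithEventTime does
val events = Seq(Ev(3000L, 1.0), Ev(1000L, 2.0), Ev(2000L, 0.5)).sortBy(_.ts)

// Fold mirrors the session update: last event time, event count, running total
val (lastTs, count, total) = events.foldLeft((events.head.ts, 0, 0.0)) {
  case ((_, c, t), e) => (e.ts, c + 1, t + e.value)
}

// Timeout fires 30 minutes of event time after the last event
val timeoutAt = lastTs + 30L * 60 * 1000
```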
val sessionStreamEventTime = inputStream
  .withWatermark("timestamp", "10 minutes")
  .groupByKey(_.userId)
  .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.EventTimeTimeout)(updateSessionWithEventTime)
Complex State Management Pattern
Real applications often need multiple state transitions and conditional outputs. Here’s a fraud detection example tracking transaction patterns.
case class Transaction(userId: String, amount: Double, location: String, timestamp: Timestamp)
case class FraudState(
userId: String,
recentTransactions: Seq[Transaction],
riskScore: Double,
flaggedAt: Option[Timestamp]
)
case class FraudAlert(userId: String, reason: String, riskScore: Double, timestamp: Timestamp)
def detectFraud(
userId: String,
inputs: Iterator[Transaction],
state: GroupState[FraudState]
): Iterator[FraudAlert] = {
  // Timeout invocations arrive with an empty inputs iterator; evict the state
  if (state.hasTimedOut) {
    state.remove()
    return Iterator.empty
  }
  val transactions = inputs.toSeq
  val currentState = if (state.exists) state.get else {
    FraudState(userId, Seq.empty, 0.0, None)
  }
// Keep only last 100 transactions
val allTransactions = (currentState.recentTransactions ++ transactions)
.sortBy(_.timestamp.getTime)(Ordering[Long].reverse)
.take(100)
  // Calculate risk factors; guard against sliding windows shorter than 2,
  // which occur when there is only a single transaction so far
  val rapidTransactions = allTransactions
    .sliding(2)
    .count {
      case Seq(t1, t2) => (t1.timestamp.getTime - t2.timestamp.getTime) < 60000 // Within 1 minute
      case _ => false
    }
  val locationChanges = allTransactions
    .sliding(2)
    .count {
      case Seq(t1, t2) => t1.location != t2.location
      case _ => false
    }
val highValueCount = transactions.count(_.amount > 1000)
val newRiskScore = (rapidTransactions * 10) + (locationChanges * 5) + (highValueCount * 15)
val updatedState = currentState.copy(
recentTransactions = allTransactions,
riskScore = newRiskScore,
flaggedAt = if (newRiskScore > 50 && currentState.flaggedAt.isEmpty)
Some(transactions.maxBy(_.timestamp.getTime).timestamp)
else currentState.flaggedAt
)
state.update(updatedState)
state.setTimeoutDuration("2 hours")
// Generate alerts if risk threshold exceeded
if (newRiskScore > 50) {
val reasons = Seq(
if (rapidTransactions > 3) Some("Rapid transactions") else None,
if (locationChanges > 5) Some("Multiple locations") else None,
if (highValueCount > 2) Some("High-value transactions") else None
).flatten.mkString(", ")
    Iterator(FraudAlert(userId, reasons, newRiskScore, transactions.maxBy(_.timestamp.getTime).timestamp))
} else {
Iterator.empty
}
}
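The risk factors are ordinary collection operations, so the scoring can be sanity-checked outside Spark. A sketch with illustrative values, using a Tx stand-in for Transaction with epoch-millis timestamps:

```scala
case class Tx(amount: Double, location: String, ts: Long)

// Newest-first, matching the sort order kept in FraudState
val txs = Seq(Tx(50.0, "NY", 120000L), Tx(30.0, "NY", 90000L), Tx(2000.0, "LA", 10000L))

// Consecutive pairs less than one minute apart (newer minus older)
val rapid = txs.sliding(2).count {
  case Seq(a, b) => (a.ts - b.ts) < 60000
  case _         => false
}

// Consecutive pairs in different locations
val moved = txs.sliding(2).count {
  case Seq(a, b) => a.location != b.location
  case _         => false
}

val highValue = txs.count(_.amount > 1000)

// Same weighting as detectFraud: 10 per rapid pair, 5 per move, 15 per high-value
val riskScore = rapid * 10 + moved * 5 + highValue * 15
```

Here one pair is 30 seconds apart, one location change occurs, and one transaction exceeds 1000, yielding a score of 30, below the alert threshold of 50.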
State Size Management and Checkpointing
State grows unbounded without proper management. Monitor state metrics and implement cleanup strategies.
val query = sessionStream.writeStream
.outputMode("update")
.option("checkpointLocation", "/tmp/checkpoint")
.foreachBatch { (batchDF: Dataset[SessionOutput], batchId: Long) =>
// Log state metrics
println(s"Batch $batchId: ${batchDF.count()} outputs")
batchDF.write.format("delta").mode("append").save("/data/sessions")
}
.start()
// Monitor state store metrics; lastProgress is null before the first completed trigger
spark.streams.active.foreach { query =>
  val progress = query.lastProgress
  if (progress != null && progress.stateOperators.nonEmpty) {
    val op = progress.stateOperators.head
    println(s"State rows: ${op.numRowsTotal}, state memory: ${op.memoryUsedBytes} bytes")
  }
}
State serialization impacts performance significantly. In Structured Streaming, state objects are serialized with Dataset encoders, so case classes like SessionState already get compact product encoders; for complex state classes without a product encoder, supply a Kryo-based encoder explicitly. (The spark.serializer setting governs shuffle and RDD serialization, not encoder-based state, and in any case must be configured when the SparkSession is built, not via spark.conf.set at runtime.)
import org.apache.spark.sql.{Encoder, Encoders}
// Kryo fallback for a state class that has no product encoder
implicit val fraudStateEncoder: Encoder[FraudState] = Encoders.kryo[FraudState]
mapGroupsWithState enables sophisticated streaming applications that require custom state logic beyond simple aggregations. Proper timeout configuration, state size management, and serialization tuning are critical for production deployments handling high-volume streams.