PySpark Structured Streaming Tutorial

Key Insights

  • Structured Streaming in PySpark processes unbounded data using the same DataFrame API as batch processing, making it easier to build real-time pipelines without learning new paradigms
  • The framework provides end-to-end exactly-once guarantees through checkpointing and write-ahead logs, provided the source is replayable and the sink is idempotent, so results stay consistent across failures
  • Window operations and watermarking enable handling of late-arriving data while preventing unbounded state growth in long-running streaming applications

Setting Up Your Streaming Environment

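Structured Streaming was introduced in Spark 2.0, but some features used below need newer releases (stream-stream joins and continuous triggers arrived in Spark 2.3). Install PySpark and create a SparkSession configured for streaming: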

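from pyspark.sql import SparkSession
# Wildcard imports keep the examples short, but they shadow Python built-ins
# such as sum, min, and max with their Spark column equivalents
from pyspark.sql.functions import *
from pyspark.sql.types import *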

spark = SparkSession.builder \
    .appName("StructuredStreamingApp") \
    .config("spark.sql.shuffle.partitions", "2") \
    .getOrCreate()

spark.sparkContext.setLogLevel("WARN")

Spark defaults to 200 shuffle partitions, which creates many tiny tasks on a single machine; lowering the value to 2 speeds up local development. Production environments typically use higher values based on cluster resources.

Creating a Streaming DataFrame from Socket Source

The simplest streaming source is a socket connection. This example reads text data from a TCP socket:

lines = spark.readStream \
    .format("socket") \
    .option("host", "localhost") \
    .option("port", 9999) \
    .load()

# Process the streaming data
word_counts = lines.select(
    explode(split(col("value"), " ")).alias("word")
).groupBy("word").count()

# Start the streaming query
query = word_counts.writeStream \
    .outputMode("complete") \
    .format("console") \
    .start()

query.awaitTermination()

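To test this, start nc -lk 9999 in a terminal before launching the query, then type words into that terminal. The streaming application updates the word counts after each micro-batch. The socket source is meant for testing only, since it cannot replay data after a failure.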
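Where nc is not available, a small Python server can stand in for it. The following is a minimal sketch, not part of the original tutorial: the serve_lines function and its ready callback are hypothetical names, and it assumes a single client (the Spark socket source) connects once.

```python
import socket
import threading


def serve_lines(lines, host="localhost", port=9999, ready=None):
    """Accept one client and send it each line, newline-terminated."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((host, port))
        server.listen(1)
        if ready is not None:
            # Report the bound port (useful when port=0 picks a free one)
            ready(server.getsockname()[1])
        conn, _ = server.accept()
        with conn:
            for line in lines:
                conn.sendall((line + "\n").encode("utf-8"))


def start_in_background(lines, **kwargs):
    """Run serve_lines in a daemon thread so the calling script can continue."""
    thread = threading.Thread(target=serve_lines, args=(lines,), kwargs=kwargs, daemon=True)
    thread.start()
    return thread
```

Calling start_in_background(["hello world", "hello spark"]) before starting the streaming query would feed those two lines to the word count once Spark connects.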

Working with File Sources

The file source is fault-tolerant and suited to production use; it monitors a directory and processes new files as they appear:

# Streaming file sources require an explicit schema by default
user_schema = StructType([
    StructField("timestamp", TimestampType(), True),
    StructField("user_id", IntegerType(), True),
    StructField("action", StringType(), True),
    StructField("value", DoubleType(), True)
])

# Read JSON files from directory
streaming_df = spark.readStream \
    .schema(user_schema) \
    .option("maxFilesPerTrigger", 1) \
    .json("/path/to/input/directory")

# Process and write to output
query = streaming_df \
    .filter(col("value") > 100) \
    .writeStream \
    .format("parquet") \
    .option("path", "/path/to/output") \
    .option("checkpointLocation", "/path/to/checkpoint") \
    .start()

The checkpoint location is critical: it stores offsets and commit metadata so that a restarted query resumes where it left off, which is what makes fault tolerance and exactly-once output possible.
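The file source expects each file to appear in the monitored directory atomically, so that a half-written file is never picked up. A minimal sketch of a helper for feeding test data follows. The function name, the staging directory, and the sample values are illustrative assumptions; the record fields mirror the schema defined above, and the staging directory is assumed to sit on the same filesystem as the input directory.

```python
import json
import os
import uuid


def drop_json_file(records, staging_dir, input_dir):
    """Write records as JSON lines, then move the file into input_dir in one step."""
    os.makedirs(staging_dir, exist_ok=True)
    os.makedirs(input_dir, exist_ok=True)
    name = f"events-{uuid.uuid4().hex}.json"
    staged = os.path.join(staging_dir, name)
    with open(staged, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
    final = os.path.join(input_dir, name)
    # os.replace is atomic when both paths are on the same filesystem
    os.replace(staged, final)
    return final


# Hypothetical sample rows shaped like the schema above
sample = [
    {"timestamp": "2024-01-01T12:00:00", "user_id": 1, "action": "click", "value": 150.0},
    {"timestamp": "2024-01-01T12:01:00", "user_id": 2, "action": "view", "value": 42.0},
]
```

Each call produces one new file, so with maxFilesPerTrigger set to 1 every call corresponds to one micro-batch.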

Implementing Windowed Aggregations

Time-based windows aggregate events within specific time intervals:

from pyspark.sql.functions import window

# Sample streaming data with timestamps
events = spark.readStream \
    .schema(user_schema) \
    .json("/path/to/events")

# 10-minute tumbling window
windowed_counts = events \
    .groupBy(
        window(col("timestamp"), "10 minutes"),
        col("action")
    ) \
    .agg(
        count("*").alias("event_count"),
        avg("value").alias("avg_value")
    )

query = windowed_counts.writeStream \
    .outputMode("update") \
    .format("console") \
    .option("truncate", "false") \
    .start()

Sliding windows overlap when you specify a slide duration shorter than the window length:

# 10-minute window sliding every 5 minutes
sliding_window = events \
    .groupBy(
        window(col("timestamp"), "10 minutes", "5 minutes"),
        col("action")
    ) \
    .count()
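To see which windows a single event lands in, the assignment can be modeled in plain Python. This is a sketch of the arithmetic only, assuming windows are aligned to the Unix epoch (Spark's default when no start offset is given); windows_for is a hypothetical helper, not a Spark function.

```python
from datetime import datetime, timedelta, timezone

EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def windows_for(event_time, window, slide=None):
    """Return the [start, end) windows that contain event_time."""
    slide = slide or window  # a tumbling window slides by its own length
    # Latest window start at or before the event, aligned to the slide interval
    start = event_time - ((event_time - EPOCH) % slide)
    result = []
    while start > event_time - window:
        result.append((start, start + window))
        start -= slide
    return sorted(result)


event = datetime(2024, 1, 1, 12, 7, tzinfo=timezone.utc)
tumbling = windows_for(event, timedelta(minutes=10))
sliding = windows_for(event, timedelta(minutes=10), timedelta(minutes=5))
```

For an event at 12:07, the tumbling case yields the single window 12:00 to 12:10, while the sliding case yields both 12:00 to 12:10 and 12:05 to 12:15, so the event is counted twice.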

Handling Late Data with Watermarking

Watermarking defines how long to wait for late events before finalizing aggregations:

# Define watermark of 10 minutes
events_with_watermark = events \
    .withWatermark("timestamp", "10 minutes")

# Aggregate with watermark
late_data_handled = events_with_watermark \
    .groupBy(
        window(col("timestamp"), "5 minutes"),
        col("user_id")
    ) \
    .agg(
        sum("value").alias("total_value")
    )

query = late_data_handled.writeStream \
    .outputMode("append") \
    .format("console") \
    .start()

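Spark tracks the maximum event time seen so far and sets the watermark 10 minutes behind it. Events whose timestamps fall behind the watermark may be dropped, and state for windows that end before the watermark is discarded. This prevents state from growing indefinitely in long-running streams.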
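The interaction between the watermark and the windows can be traced with a simplified plain-Python model. It is a sketch under stated assumptions: event times are integer minutes, the watermark advances only between micro-batches, and an event is dropped once its window has already closed. Real Spark only guarantees that data within the threshold is kept; later data may or may not be dropped. The function name is hypothetical.

```python
def simulate_watermark(batches, delay, window):
    """Split events into accepted and dropped under a simplified watermark rule."""
    max_event_time = None
    watermark = None
    accepted, dropped = [], []
    for batch in batches:
        for t in batch:
            window_end = (t // window + 1) * window
            # A window is closed once the watermark has reached its end
            if watermark is not None and window_end <= watermark:
                dropped.append(t)
            else:
                accepted.append(t)
            if max_event_time is None or t > max_event_time:
                max_event_time = t
        # The watermark moves only after the batch completes
        if max_event_time is not None:
            watermark = max_event_time - delay
    return accepted, dropped, watermark


accepted, dropped, watermark = simulate_watermark(
    batches=[[12, 14], [30, 13, 21], [16, 19, 26]],
    delay=10,
    window=5,
)
```

After the second batch the maximum event time is 30, so the watermark is 20. In the third batch the events at minutes 16 and 19 belong to the window ending at 20 and are dropped, while the event at minute 26 is still accepted.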

Kafka Integration for Production Streaming

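Kafka is a common message broker for production streaming applications. The Kafka source and sink ship separately from core Spark, so the spark-sql-kafka-0-10 connector package matching your Spark version must be on the classpath: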

# Read from Kafka
kafka_stream = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "user-events") \
    .option("startingOffsets", "latest") \
    .load()

# Kafka messages are binary - parse the value
parsed_stream = kafka_stream.select(
    col("key").cast("string"),
    from_json(
        col("value").cast("string"),
        user_schema
    ).alias("data")
).select("key", "data.*")

# Process and write back to Kafka
processed = parsed_stream \
    .filter(col("value") > 50) \
    .select(
        col("user_id").cast("string").alias("key"),
        to_json(struct("*")).alias("value")
    )

query = processed.writeStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("topic", "processed-events") \
    .option("checkpointLocation", "/path/to/checkpoint") \
    .start()

Stream-Stream Joins

Join two streaming DataFrames based on time windows:

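# Minimal schemas assumed for this example; only the join columns are shown
impression_schema = StructType([
    StructField("user_id", IntegerType(), True),
    StructField("impression_time", TimestampType(), True)
])
click_schema = StructType([
    StructField("click_user_id", IntegerType(), True),
    StructField("click_time", TimestampType(), True)
])

# First stream: user impressions
impressions = spark.readStream \
    .schema(impression_schema) \
    .json("/path/to/impressions")

# Second stream: user clicks
clicks = spark.readStream \
    .schema(click_schema) \
    .json("/path/to/clicks")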

# Join with watermarks and time constraints
impressions_with_wm = impressions.withWatermark("impression_time", "2 hours")
clicks_with_wm = clicks.withWatermark("click_time", "3 hours")

joined = impressions_with_wm.join(
    clicks_with_wm,
    expr("""
        user_id = click_user_id AND
        click_time >= impression_time AND
        click_time <= impression_time + interval 1 hour
    """)
)

query = joined.writeStream \
    .outputMode("append") \
    .format("console") \
    .start()

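An inner stream-stream join runs without watermarks, but Spark must then keep every row from both sides in state forever. Watermarks on both streams, combined with a time-range condition in the join, let Spark discard rows that can no longer match. Outer joins require both.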
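The join condition itself can be checked on a handful of rows in plain Python. This sketch models only the matching rule (same user, click within one hour after the impression), not Spark's state handling; match_clicks and the sample rows are illustrative assumptions.

```python
from datetime import datetime, timedelta


def match_clicks(impressions, clicks, max_delay=timedelta(hours=1)):
    """Pair impressions with clicks by the same user within max_delay afterwards."""
    matches = []
    for user_id, impression_time in impressions:
        for click_user_id, click_time in clicks:
            same_user = user_id == click_user_id
            in_range = impression_time <= click_time <= impression_time + max_delay
            if same_user and in_range:
                matches.append((user_id, impression_time, click_time))
    return matches


noon = datetime(2024, 1, 1, 12, 0)
impressions = [(1, noon), (2, noon)]
clicks = [
    (1, noon + timedelta(minutes=30)),  # within the hour: matches
    (1, noon + timedelta(minutes=90)),  # too late: no match
    (2, noon - timedelta(minutes=1)),   # before the impression: no match
]
matched = match_clicks(impressions, clicks)
```

The one-hour bound is what allows Spark to expire an impression from state once the click watermark has moved far enough past it.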

Output Modes and Sinks

Three output modes control what gets written:

# Complete: entire result table (only for aggregations)
complete_query = word_counts.writeStream \
    .outputMode("complete") \
    .format("memory") \
    .queryName("word_counts_table") \
    .start()

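# Append: only rows that will never change (default; aggregations need a watermark)
# filtered_events stands for any non-aggregated streaming DataFrame
append_query = filtered_events.writeStream \
    .outputMode("append") \
    .format("parquet") \
    .option("path", "/output") \
    .option("checkpointLocation", "/path/to/checkpoint") \
    .start()

# Update: only rows changed since the last trigger (for aggregations)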
update_query = aggregated_events.writeStream \
    .outputMode("update") \
    .format("console") \
    .start()
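The difference between complete and update mode shows up clearly when a running word count is traced over two micro-batches. The following plain-Python sketch models what each mode would emit per trigger; the emit function is a hypothetical illustration, not a Spark API.

```python
from collections import Counter


def emit(batches, mode):
    """Return what a running word count would output after each micro-batch."""
    totals = Counter()
    outputs = []
    for batch in batches:
        changed = Counter(batch)
        totals.update(changed)
        if mode == "complete":
            # The whole result table is rewritten every trigger
            outputs.append(dict(totals))
        elif mode == "update":
            # Only keys touched by this batch are written
            outputs.append({word: totals[word] for word in changed})
        else:
            raise ValueError(f"unsupported mode: {mode}")
    return outputs


batches = [["a", "b", "a"], ["b", "c"]]
complete_output = emit(batches, "complete")
update_output = emit(batches, "update")
```

In the second trigger, complete mode re-emits the unchanged count for "a" while update mode emits only "b" and "c", which matters for sinks where rewriting the full table is expensive.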

Monitoring and Managing Queries

Access streaming query metrics programmatically:

# Get active streams
active_queries = spark.streams.active

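# Use a separate loop variable so the query handle above is not overwritten
for q in active_queries:
    print(f"Query ID: {q.id}")
    print(f"Status: {q.status}")
    print(f"Recent Progress: {q.recentProgress}")

# Wait up to 60 seconds; returns True if the query terminated in that time
query.awaitTermination(timeout=60)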

# Stop query gracefully
query.stop()

# Exception handling
try:
    query.awaitTermination()
except Exception as e:
    print(f"Stream failed: {e}")
    query.stop()
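In PySpark, recentProgress is a list of dictionaries, one per completed micro-batch, so it can be reduced to a few headline numbers. A minimal sketch follows; summarize_progress is a hypothetical helper, and the sample entries are hand-written stand-ins that use only the numInputRows and processedRowsPerSecond fields.

```python
def summarize_progress(progress_entries):
    """Reduce a list of progress dictionaries to a small summary."""
    batch_count = 0
    total_rows = 0
    rate_total = 0.0
    rate_count = 0
    for entry in progress_entries:
        batch_count += 1
        total_rows += entry.get("numInputRows", 0)
        rate = entry.get("processedRowsPerSecond")
        if rate is not None:
            rate_total += rate
            rate_count += 1
    return {
        "batches": batch_count,
        "total_input_rows": total_rows,
        "avg_processed_rows_per_sec": rate_total / rate_count if rate_count else None,
    }


# Hand-written sample entries for illustration
sample_progress = [
    {"batchId": 0, "numInputRows": 100, "processedRowsPerSecond": 50.0},
    {"batchId": 1, "numInputRows": 300, "processedRowsPerSecond": 150.0},
]
summary = summarize_progress(sample_progress)
```

With a live query, summarize_progress(query.recentProgress) would give the same summary. The loop avoids the built-in sum on purpose, because the wildcard import of pyspark.sql.functions used in this tutorial shadows it.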

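Implementing Stateful Processing with applyInPandasWithState

For complex stateful logic beyond built-in aggregations, PySpark (3.4 and later) provides applyInPandasWithState. The mapGroupsWithState operator exists only in the Scala and Java APIs:

import pandas as pd
from pyspark.sql.streaming.state import GroupState, GroupStateTimeout

def update_user_session(key, batches, state: GroupState):
    (user_id,) = key

    # On timeout, emit the final count and clear the state
    if state.hasTimedOut:
        (final_count,) = state.get
        state.remove()
        yield pd.DataFrame({"user_id": [user_id], "event_count": [final_count], "status": ["expired"]})
        return

    # Add the rows from this trigger to the running count
    (count,) = state.get if state.exists else (0,)
    for pdf in batches:
        count += pdf.shape[0]

    state.update((count,))
    state.setTimeoutDuration(60000)  # 1 minute timeout

    yield pd.DataFrame({"user_id": [user_id], "event_count": [count], "status": ["active"]})

# Apply stateful function
stateful_stream = events \
    .groupBy("user_id") \
    .applyInPandasWithState(
        update_user_session,
        outputStructType="user_id INT, event_count LONG, status STRING",
        stateStructType="event_count LONG",
        outputMode="update",
        timeoutConf=GroupStateTimeout.ProcessingTimeTimeout
    )

This pattern enables session tracking, custom aggregations, and complex event processing that standard SQL operations cannot express.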

Performance Optimization Strategies

Tune trigger intervals to balance latency and throughput:

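# Micro-batch every 30 seconds
query = processed_stream.writeStream \
    .trigger(processingTime='30 seconds') \
    .format("parquet") \
    .option("path", "/path/to/output") \
    .option("checkpointLocation", "/path/to/checkpoint") \
    .start()

# Continuous processing (experimental, low latency, map-like operations only)
# The interval is how often progress is checkpointed, not a batch size
query = processed_stream.writeStream \
    .trigger(continuous='1 second') \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("topic", "processed-events") \
    .option("checkpointLocation", "/path/to/checkpoint") \
    .start()

# Process available data then stop (on Spark 3.3+, prefer availableNow=True)
query = processed_stream.writeStream \
    .trigger(once=True) \
    .format("parquet") \
    .option("path", "/path/to/output") \
    .option("checkpointLocation", "/path/to/checkpoint") \
    .start()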

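Partition output by columns that downstream queries filter on, such as date and hour, so readers can skip irrelevant directories:

# processed_stream must already contain the date and hour columns
query = processed_stream.writeStream \
    .partitionBy("date", "hour") \
    .format("parquet") \
    .option("path", "/output") \
    .option("checkpointLocation", "/path/to/checkpoint") \
    .start()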
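Partitioned file output uses one directory level per partition column, named column=value. The sketch below computes the directory a given event would land in, assuming date is a date column and hour an integer column derived from the event timestamp; partition_dir is a hypothetical helper for illustration only.

```python
from datetime import datetime


def partition_dir(base_path, event_time):
    """Return the column=value directory for an event's date and hour."""
    return f"{base_path}/date={event_time:%Y-%m-%d}/hour={event_time.hour}"


example_dir = partition_dir("/output", datetime(2024, 1, 1, 9, 30))
```

A query that filters on one date and hour then reads a single directory instead of scanning the whole output path.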

Structured Streaming transforms real-time data processing by unifying batch and streaming APIs. The checkpoint mechanism ensures reliability, while watermarking and windowing provide precise control over late data and state management. Production deployments should monitor query progress, implement proper error handling, and tune trigger intervals based on latency requirements.
