PySpark - Read from Kafka with Structured Streaming
Key Insights
- PySpark Structured Streaming provides a declarative API for consuming Kafka messages, with automatic offset management through checkpointing and end-to-end exactly-once semantics when the sink is idempotent (for example, the file or Delta sink)
- The kafka format in Structured Streaming treats each Kafka message as a row with a fixed schema including key, value, topic, partition, offset, and timestamp columns
- Production deployments require careful configuration of trigger intervals, checkpoint locations, and deserialization strategies to balance throughput with fault tolerance
Connecting to Kafka with Structured Streaming
PySpark’s Structured Streaming API treats Kafka as a structured data source, enabling you to read from topics using the familiar DataFrame API. The basic connection requires the Kafka bootstrap servers and topic names.
from pyspark.sql import SparkSession
spark = SparkSession.builder \
.appName("KafkaStructuredStreaming") \
.config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0") \
.getOrCreate()
df = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "localhost:9092") \
.option("subscribe", "user-events") \
.load()
df.printSchema()
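The spark.jars.packages coordinate has to match both your Spark version and the Scala version of your Spark build; a mismatched connector is a common cause of errors when the query starts. A minimal sketch of a hypothetical helper that assembles the coordinate from those two versions, assuming the artifact naming used in the builder above:

```python
def kafka_connector_package(spark_version: str, scala_version: str = "2.12") -> str:
    """Build the Maven coordinate for the Structured Streaming Kafka connector.

    Hypothetical helper: it only formats the string passed to spark.jars.packages.
    """
    # Keep only the Scala binary version (for example "2.12" from "2.12.18")
    scala_binary = ".".join(scala_version.split(".")[:2])
    return f"org.apache.spark:spark-sql-kafka-0-10_{scala_binary}:{spark_version}"


# Reproduces the coordinate used in the SparkSession builder above
print(kafka_connector_package("3.5.0"))
```

One option is to pass pyspark.__version__ as spark_version so the connector tracks the installed PySpark release.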
The Kafka source produces a DataFrame with this schema:
root
|-- key: binary (nullable = true)
|-- value: binary (nullable = true)
|-- topic: string (nullable = true)
|-- partition: integer (nullable = true)
|-- offset: long (nullable = true)
|-- timestamp: timestamp (nullable = true)
|-- timestampType: integer (nullable = true)
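The timestampType column holds Kafka's numeric timestamp-type id: 0 means the producer set the timestamp (CreateTime), 1 means the broker assigned it when the record was appended (LogAppendTime), and -1 means the record carries no timestamp. A small sketch that labels these ids; the mapping and function names are hypothetical, and inside a streaming query you would express the same lookup with when/otherwise column expressions.

```python
# Kafka's timestamp-type ids as they appear in the timestampType column
TIMESTAMP_TYPES = {
    -1: "NoTimestampType",  # record has no timestamp (very old message formats)
    0: "CreateTime",        # timestamp set by the producer
    1: "LogAppendTime",     # timestamp set by the broker on append
}


def describe_timestamp_type(type_id: int) -> str:
    """Return Kafka's name for a timestampType id, or 'Unknown' for other values."""
    return TIMESTAMP_TYPES.get(type_id, "Unknown")


print(describe_timestamp_type(0))  # CreateTime
print(describe_timestamp_type(1))  # LogAppendTime
```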
Deserializing Kafka Messages
Kafka stores message keys and values as byte arrays, so you need to deserialize them according to your data format. For JSON messages, cast the value column to string (which decodes the bytes as UTF-8) and parse it with from_json against an explicit schema.
from pyspark.sql.functions import col, from_json
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType
# Define your message schema
message_schema = StructType([
StructField("user_id", StringType(), True),
StructField("event_type", StringType(), True),
StructField("product_id", StringType(), True),
StructField("quantity", IntegerType(), True),
StructField("event_time", TimestampType(), True)
])
# Deserialize and parse JSON
parsed_df = df.select(
col("key").cast("string").alias("message_key"),
from_json(col("value").cast("string"), message_schema).alias("data"),
col("topic"),
col("partition"),
col("offset"),
col("timestamp")
).select("message_key", "data.*", "topic", "partition", "offset", "timestamp")
parsed_df.printSchema()
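For intuition, the per-row work of cast("string") followed by from_json can be mirrored in plain Python: decode the value bytes as UTF-8, parse the JSON, and project the schema's fields, leaving absent fields as null. This is a conceptual sketch with a hypothetical function name, not Spark's implementation; Spark also coerces values to the schema's types, and its handling of malformed records depends on from_json's mode option rather than the simple None returned here.

```python
import json
from typing import Optional

# Field names taken from message_schema above
FIELDS = ["user_id", "event_type", "product_id", "quantity", "event_time"]


def parse_value(value: Optional[bytes]) -> Optional[dict]:
    """Decode a Kafka value and keep only the schema's fields (conceptual sketch)."""
    if value is None:
        return None  # null values (for example, tombstones) stay null
    try:
        record = json.loads(value.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError):
        return None  # simplified: treat malformed input as null
    if not isinstance(record, dict):
        return None
    # Missing fields become None, like nullable columns in the parsed struct
    return {name: record.get(name) for name in FIELDS}


raw = b'{"user_id": "u1", "event_type": "purchase", "quantity": 2}'
print(parse_value(raw))
```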
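For Avro-serialized messages, use from_avro with the writer's schema as a JSON string. from_avro requires the external spark-avro package (org.apache.spark:spark-avro_2.12), and it expects plain Avro binary: messages written by Confluent Schema Registry serializers start with a 5-byte header (a magic byte and a 4-byte schema ID) that must be removed before decoding.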
from pyspark.sql.avro.functions import from_avro
# Avro schema as JSON string
avro_schema = """
{
"type": "record",
"name": "UserEvent",
"fields": [
{"name": "user_id", "type": "string"},
{"name": "event_type", "type": "string"},
{"name": "product_id", "type": "string"},
{"name": "quantity", "type": "int"}
]
}
"""
avro_df = df.select(
from_avro(col("value"), avro_schema).alias("data")
).select("data.*")
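When producers use Confluent Schema Registry serializers, each value is framed as a zero magic byte, a 4-byte big-endian schema ID, and then the Avro payload. A minimal plain-Python sketch of splitting that header off a single message; the function name is hypothetical. In a DataFrame, one common approach is to drop the first five bytes with expr("substring(value, 6, length(value) - 5)") before calling from_avro, which assumes every message was written with the same schema.

```python
import struct

MAGIC_BYTE = 0  # first byte of the Confluent Schema Registry wire format


def split_confluent_frame(value: bytes) -> tuple[int, bytes]:
    """Return (schema_id, avro_payload) from a Schema Registry-framed message."""
    if len(value) < 5:
        raise ValueError("message too short for the 5-byte header")
    # ">bI": big-endian signed byte (magic) followed by a 4-byte unsigned int (schema ID)
    magic, schema_id = struct.unpack(">bI", value[:5])
    if magic != MAGIC_BYTE:
        raise ValueError(f"unexpected magic byte {magic}")
    return schema_id, value[5:]


framed = bytes([0, 0, 0, 0, 42]) + b"avro-bytes"
print(split_confluent_frame(framed))  # (42, b'avro-bytes')
```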
Subscribing to Multiple Topics
Structured Streaming supports three subscription options: specific topics (subscribe), topic patterns (subscribePattern), and assignments of specific partitions (assign). Only one of them can be set on a given reader.
# Subscribe to multiple specific topics
multi_topic_df = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "localhost:9092") \
.option("subscribe", "topic1,topic2,topic3") \
.load()
# Subscribe using pattern matching
pattern_df = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "localhost:9092") \
.option("subscribePattern", "user-events-.*") \
.load()
# Subscribe to specific partitions
partition_df = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "localhost:9092") \
.option("assign", """{"user-events":[0,1,2]}""") \
.load()
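Both subscribePattern and assign take strings, so it can help to build and check them in plain Python first. A sketch with hypothetical helper names: json.dumps produces the assign payload, and re.fullmatch approximates how a pattern selects topics, since Kafka matches the pattern against the whole topic name. Kafka uses Java regular expressions, whose syntax differs from Python's in some corners, so treat this as an approximation.

```python
import json
import re


def assign_option(partitions: dict[str, list[int]]) -> str:
    """Build the JSON string expected by the 'assign' option."""
    return json.dumps(partitions, separators=(",", ":"))


def matching_topics(pattern: str, topics: list[str]) -> list[str]:
    """Approximate which topics a subscribePattern value selects (whole-name match)."""
    compiled = re.compile(pattern)
    return [t for t in topics if compiled.fullmatch(t)]


print(assign_option({"user-events": [0, 1, 2]}))
# {"user-events":[0,1,2]}
print(matching_topics("user-events-.*", ["user-events-eu", "user-events", "audit-user-events-us"]))
# ['user-events-eu']
```

Note that "audit-user-events-us" is excluded because the pattern must match from the first character, which a substring search would not enforce.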
Configuring Starting Offsets
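Control where to start reading from Kafka topics using the startingOffsets option. This is critical for reprocessing historical data or starting fresh. Note that startingOffsets only applies when a query starts for the first time; a query restarted from an existing checkpoint resumes from the offsets recorded there.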
# Start from earliest available offset
earliest_df = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "localhost:9092") \
.option("subscribe", "user-events") \
.option("startingOffsets", "earliest") \
.load()
# Start from latest offset (default)
latest_df = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "localhost:9092") \
.option("subscribe", "user-events") \
.option("startingOffsets", "latest") \
.load()
# Start from specific offsets per partition
specific_offsets = """{"user-events":{"0":23,"1":45,"2":67}}"""
specific_df = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "localhost:9092") \
.option("subscribe", "user-events") \
.option("startingOffsets", specific_offsets) \
.load()
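The per-partition startingOffsets string is also JSON. Inside it, Spark accepts -2 to mean earliest and -1 to mean latest for an individual partition, so explicit offsets can be mixed with those markers. A sketch with a hypothetical helper that serializes such a map; partition numbers become string keys, as in the example above.

```python
import json

EARLIEST = -2  # per-partition marker for the earliest available offset
LATEST = -1    # per-partition marker for the latest offset


def starting_offsets(topic: str, offsets: dict[int, int]) -> str:
    """Serialize {partition: offset} for one topic into a startingOffsets string."""
    per_partition = {str(p): o for p, o in sorted(offsets.items())}
    return json.dumps({topic: per_partition}, separators=(",", ":"))


print(starting_offsets("user-events", {0: 23, 1: 45, 2: 67}))
# {"user-events":{"0":23,"1":45,"2":67}}
print(starting_offsets("user-events", {0: 23, 1: EARLIEST, 2: LATEST}))
# {"user-events":{"0":23,"1":-2,"2":-1}}
```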
Processing Kafka Streams with Transformations
Apply standard DataFrame transformations to process streaming data. This example parses events, keeps only purchases, and sums purchased quantities per user over 5-minute windows that slide every minute.
from pyspark.sql.functions import window, sum as _sum, count
# Parse and filter messages
events_df = df.select(
from_json(col("value").cast("string"), message_schema).alias("data"),
col("timestamp")
).select("data.*", "timestamp")
filtered_df = events_df.filter(col("event_type") == "purchase")
# Windowed aggregation
windowed_aggregates = filtered_df \
.withWatermark("event_time", "10 minutes") \
.groupBy(
window(col("event_time"), "5 minutes", "1 minute"),
col("user_id")
) \
.agg(
_sum("quantity").alias("total_quantity"),
count("*").alias("event_count")
)
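With a 5-minute window sliding every minute, each event falls into five overlapping windows, and the 10-minute watermark lets Spark finalize a window roughly once the latest event time seen is 10 minutes past the window's end. A plain-Python sketch of that arithmetic, using seconds and hypothetical helper names; it models the default window start offset of zero and simplifies the exact boundary Spark uses.

```python
MINUTE = 60  # seconds


def sliding_windows(ts: int, size: int = 5 * MINUTE, slide: int = MINUTE) -> list[tuple[int, int]]:
    """Return every [start, end) window of the given size and slide that contains ts."""
    start = ts - ts % slide  # latest window start at or before ts
    windows = []
    while start > ts - size:  # [start, start + size) still contains ts
        windows.append((start, start + size))
        start -= slide
    return sorted(windows)


def watermark(max_event_time: int, delay: int = 10 * MINUTE) -> int:
    """Watermark: the latest event time seen so far minus the allowed lateness."""
    return max_event_time - delay


event = 12 * MINUTE + 30  # 00:12:30, in seconds from an arbitrary origin
print(sliding_windows(event))
# [(480, 780), (540, 840), (600, 900), (660, 960), (720, 1020)]

# Windows whose end the watermark has reached can be finalized (boundary simplified)
wm = watermark(26 * MINUTE)  # 1560 - 600 = 960
print([w for w in sliding_windows(event) if w[1] <= wm])
# [(480, 780), (540, 840), (600, 900), (660, 960)]
```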
Writing Query Results with Checkpointing
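Checkpointing records the offsets of each micro-batch (and any aggregation state) so a restarted query resumes exactly where it left off. Combined with a replayable source like Kafka and an idempotent sink such as the file or Delta sink, this gives end-to-end exactly-once output. Always specify a checkpoint location in production.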
# Console output for debugging
console_query = parsed_df.writeStream \
.outputMode("append") \
.format("console") \
.option("truncate", False) \
.option("numRows", 20) \
.start()
# Write to Parquet with checkpointing
parquet_query = parsed_df.writeStream \
.outputMode("append") \
.format("parquet") \
.option("path", "/data/output/user-events") \
.option("checkpointLocation", "/data/checkpoints/user-events") \
.trigger(processingTime="30 seconds") \
.start()
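# Write aggregated results to Delta Lake (requires the delta-spark package).
# The Delta sink supports append and complete modes; with the watermark,
# append emits each window once, after the watermark passes its end.
delta_query = windowed_aggregates.writeStream \
.outputMode("append") \
.format("delta") \
.option("path", "/data/delta/user-aggregates") \
.option("checkpointLocation", "/data/checkpoints/user-aggregates") \
.trigger(processingTime="1 minute") \
.start()
# Block until any active query stops or fails, not just the Parquet one
spark.streams.awaitAnyTermination()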
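Why a checkpoint plus an idempotent sink yields exactly-once output can be shown with a toy simulation. Spark records a batch's offset range before running it, so after a crash it replays the same batch ID over the same offsets, and a sink that remembers which batch IDs it has written can skip the duplicate. The classes below are a hypothetical model of that protocol, not Spark internals; in real queries, foreachBatch passes the batch ID to your function for the same purpose.

```python
class Checkpoint:
    """Toy stand-in for a query's offset log (planned ranges) and commit log."""

    def __init__(self) -> None:
        self.planned: dict[int, tuple[int, int]] = {}  # batch_id -> (start, end)
        self.completed: set[int] = set()


class IdempotentSink:
    """Toy sink that skips any batch ID it has already written."""

    def __init__(self) -> None:
        self.rows: list[str] = []
        self.written: set[int] = set()

    def write(self, batch_id: int, rows: list[str]) -> None:
        if batch_id in self.written:
            return  # replayed batch: its output is already present
        self.rows.extend(rows)
        self.written.add(batch_id)


def run(log: list[str], cp: Checkpoint, sink: IdempotentSink,
        batch_size: int, crash_in: int = -1) -> None:
    """Process the log in micro-batches, resuming after the last completed batch."""
    batch_id = max(cp.completed, default=-1) + 1
    while True:
        if batch_id not in cp.planned:
            start = cp.planned[batch_id - 1][1] if batch_id > 0 else 0
            if start >= len(log):
                return  # nothing new to read
            # Record the offset range before processing, like the offset log
            cp.planned[batch_id] = (start, min(start + batch_size, len(log)))
        start, end = cp.planned[batch_id]  # a replay reuses the recorded range
        sink.write(batch_id, log[start:end])
        if batch_id == crash_in:
            raise RuntimeError("crashed before the batch was marked complete")
        cp.completed.add(batch_id)
        batch_id += 1


events = [f"event-{i}" for i in range(6)]
cp, sink = Checkpoint(), IdempotentSink()
try:
    run(events, cp, sink, batch_size=2, crash_in=1)  # batch 1 is written, then we crash
except RuntimeError:
    pass
run(events, cp, sink, batch_size=2)  # restart replays batch 1, and the sink skips it
print(len(sink.rows))  # 6, with no duplicates
```

A sink without the written-ID check would receive batch 1 twice here, which is the at-least-once behavior you get from sinks that cannot deduplicate.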
Production Configuration
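Configure Kafka consumer properties and Spark settings for production workloads. In the reader below, maxOffsetsPerTrigger caps the number of offsets each micro-batch reads, minPartitions asks Spark to split the Kafka partitions into at least that many Spark tasks, and options prefixed with kafka. are passed through to the underlying Kafka consumer. Together these settings shape throughput, fault tolerance, and resource utilization.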
production_df = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "broker1:9092,broker2:9092,broker3:9092") \
.option("subscribe", "user-events") \
.option("startingOffsets", "latest") \
.option("maxOffsetsPerTrigger", 100000) \
.option("minPartitions", 10) \
.option("kafka.session.timeout.ms", "30000") \
.option("kafka.request.timeout.ms", "40000") \
.option("kafka.max.poll.records", "1000") \
.option("kafka.max.poll.interval.ms", "600000") \
.option("failOnDataLoss", "false") \
.load()
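# Configure Spark streaming settings
# Schema inference applies only to file-based streaming sources; the Kafka schema is fixed
spark.conf.set("spark.sql.streaming.schemaInference", "false")
# Report Dropwizard metrics for active streaming queries through Spark's metrics system
spark.conf.set("spark.sql.streaming.metricsEnabled", "true")
# Spark disables AQE for streaming queries anyway; this mainly affects batch jobs in the session
spark.conf.set("spark.sql.adaptive.enabled", "false")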
query = production_df \
.select(from_json(col("value").cast("string"), message_schema).alias("data")) \
.select("data.*") \
.writeStream \
.outputMode("append") \
.format("parquet") \
.option("path", "/data/output/events") \
.option("checkpointLocation", "/data/checkpoints/events") \
.trigger(processingTime="10 seconds") \
.start()
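maxOffsetsPerTrigger is a total across all topic partitions, and Spark divides it among partitions in proportion to how many new offsets each one has. A plain-Python approximation with a hypothetical helper name (Spark's exact rounding differs slightly), followed by the throughput ceiling implied by pairing 100000 offsets per trigger with the 10-second trigger above, assuming every batch finishes within its interval.

```python
def split_offsets(available: dict[int, int], max_offsets: int) -> dict[int, int]:
    """Approximate each partition's share of maxOffsetsPerTrigger, proportional to its backlog."""
    total = sum(available.values())
    if total <= max_offsets:
        return dict(available)  # the whole backlog fits into one micro-batch
    # Integer division rounds down; Spark's own rounding differs slightly
    return {p: max_offsets * n // total for p, n in available.items()}


backlog = {0: 300_000, 1: 100_000, 2: 100_000}  # new offsets per Kafka partition
print(split_offsets(backlog, 100_000))  # {0: 60000, 1: 20000, 2: 20000}

# Ceiling on the sustained read rate if every batch finishes within its 10-second trigger
max_offsets_per_trigger = 100_000
trigger_seconds = 10
print(max_offsets_per_trigger / trigger_seconds)  # 10000.0 offsets per second
```

If producers write faster than that ceiling for long periods, lag grows, so the cap should be sized against peak input rates rather than averages.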
Monitoring and Error Handling
Monitor streaming queries and handle failures gracefully. Access query progress and metrics programmatically.
# Get query status and metrics (status is returned as a dict)
status = query.status
print(f"Query ID: {query.id}")
print(f"Is Active: {query.isActive}")
print(f"Status Message: {status['message']}")
# Access recent progress: one entry per recent micro-batch, readable by key
recent_progress = query.recentProgress
for progress in recent_progress:
    print(f"Batch: {progress['batchId']}")
    print(f"Input Rows: {progress['numInputRows']}")
    print(f"Processing Rate: {progress['processedRowsPerSecond']}")
# Wait up to an hour (timeout is in seconds); a failed query re-raises its error here
try:
    query.awaitTermination(timeout=3600)
except Exception as e:
    print(f"Query failed: {e}")
    query.stop()
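To track offset lag, compare the offsets a query has processed with the latest offsets available in Kafka. Recent Spark versions report both for each source in a progress entry (endOffset and latestOffset, in the same {topic: {partition: offset}} shape as startingOffsets), but check which fields your version emits. A sketch of the comparison with a hypothetical helper and made-up offsets:

```python
def partition_lag(processed: dict[str, dict[str, int]],
                  latest: dict[str, dict[str, int]]) -> dict[str, dict[str, int]]:
    """Per-partition lag: latest available offset minus the next offset to process."""
    lag: dict[str, dict[str, int]] = {}
    for topic, partitions in latest.items():
        done = processed.get(topic, {})
        # Partitions missing from `processed` are counted from offset 0 (simplification)
        lag[topic] = {p: off - done.get(p, 0) for p, off in partitions.items()}
    return lag


end_offset = {"user-events": {"0": 1200, "1": 980}}     # next offsets the query will read
latest_offset = {"user-events": {"0": 1500, "1": 980}}  # newest offsets in Kafka
lag = partition_lag(end_offset, latest_offset)
print(lag)                                # {'user-events': {'0': 300, '1': 0}}
print(sum(lag["user-events"].values()))  # 300
```

Against a live query, you might feed it the endOffset and latestOffset entries from query.lastProgress["sources"], assuming lastProgress is not None yet and those fields are present.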
The failOnDataLoss option controls what happens when offsets the query still needs are no longer available in Kafka, for example because retention deleted them or the topic was deleted. With false, the query continues from the earliest offsets still available instead of failing; set it in production only when you accept that trade of correctness for availability. Always monitor offset lag and processing rates to ensure your streaming application keeps pace with incoming data.
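When retention has removed offsets a query had not read yet, the gap between the query's next offset and the earliest offset still on the broker is the number of records skipped with failOnDataLoss set to false. A sketch of that arithmetic with a hypothetical helper name, for a single partition:

```python
def skipped_records(next_offset: int, earliest_available: int) -> int:
    """Records lost when retention deleted offsets the query had not read yet."""
    return max(0, earliest_available - next_offset)


print(skipped_records(next_offset=5_000, earliest_available=7_500))  # 2500
print(skipped_records(next_offset=8_000, earliest_available=7_500))  # 0
```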