PySpark - Streaming Join with Static DataFrame

Key Insights

  • Stream-static joins in PySpark enable real-time enrichment of streaming data with reference data, avoiding expensive stateful stream-stream joins when one side changes infrequently
  • Caching the static DataFrame lets Spark reuse it across micro-batches instead of re-reading the source each time, reducing I/O and keeping join latency predictable
  • Proper handling of schema evolution and data refresh patterns is critical—use foreachBatch to reload the static DataFrame when reference data changes, or consider Delta Lake so each micro-batch reads the latest table snapshot

Understanding Stream-Static Join Architecture

Stream-static joins combine a streaming DataFrame with a static (batch) DataFrame. This pattern is essential when enriching streaming events with reference data like user profiles, product catalogs, or configuration tables that change infrequently.

Unlike stream-stream joins that require watermarking and state management, stream-static joins are stateless operations. PySpark broadcasts the static DataFrame to all executors once per micro-batch, making the join operation efficient and predictable.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, from_json
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType

spark = SparkSession.builder \
    .appName("StreamStaticJoin") \
    .config("spark.sql.shuffle.partitions", "4") \
    .getOrCreate()

# Define schema for streaming data
streaming_schema = StructType([
    StructField("transaction_id", StringType(), False),
    StructField("user_id", StringType(), False),
    StructField("amount", IntegerType(), False),
    StructField("timestamp", TimestampType(), False)
])

# Create streaming DataFrame from Kafka
streaming_df = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "transactions") \
    .load() \
    .selectExpr("CAST(value AS STRING) as json_data") \
    .select(from_json(col("json_data"), streaming_schema).alias("data")) \
    .select("data.*")

Building the Static Reference DataFrame

The static DataFrame contains reference data that enriches streaming records. Load this data from databases, files, or data warehouses. The key consideration is ensuring the data fits comfortably in memory across your cluster since it will be broadcast.

# Static user profile data
static_schema = StructType([
    StructField("user_id", StringType(), False),
    StructField("user_name", StringType(), False),
    StructField("tier", StringType(), False),
    StructField("region", StringType(), False),
    StructField("risk_score", IntegerType(), False)
])

static_df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://localhost:5432/userdb") \
    .option("dbtable", "user_profiles") \
    .option("user", "admin") \
    .option("password", "password") \
    .load()

# Alternative: Load from parquet for better performance
# static_df = spark.read.parquet("s3://bucket/user_profiles/")

# Cache the static DataFrame to avoid repeated reads
static_df.cache()
static_df.count()  # Materialize the cache

Implementing the Join Operation

The join operation between streaming and static DataFrames follows standard DataFrame join syntax. PySpark automatically detects the static DataFrame and optimizes the execution plan accordingly.

# Perform inner join
enriched_stream = streaming_df.join(
    static_df,
    streaming_df.user_id == static_df.user_id,
    "inner"
).select(
    streaming_df.transaction_id,
    streaming_df.user_id,
    streaming_df.amount,
    streaming_df.timestamp,
    static_df.user_name,
    static_df.tier,
    static_df.region,
    static_df.risk_score
)

# Write enriched stream to output
query = enriched_stream.writeStream \
    .format("parquet") \
    .option("path", "s3://bucket/enriched_transactions/") \
    .option("checkpointLocation", "s3://bucket/checkpoints/enriched/") \
    .outputMode("append") \
    .start()

For left outer joins where you want to preserve all streaming records even without matching reference data:

from pyspark.sql.functions import coalesce, lit

enriched_stream = streaming_df.join(
    static_df,
    streaming_df.user_id == static_df.user_id,
    "left_outer"
).select(
    streaming_df.transaction_id,
    streaming_df.user_id,
    streaming_df.amount,
    streaming_df.timestamp,
    static_df.user_name.alias("user_name"),
    static_df.tier.alias("tier"),
    coalesce(static_df.risk_score, lit(0)).alias("risk_score")
)

Refreshing Static Data During Stream Processing

Static DataFrames remain constant throughout the streaming query lifecycle by default. For reference data that updates periodically, use foreachBatch to reload the static DataFrame for each micro-batch.

def process_batch(batch_df, batch_id):
    # Reload static DataFrame for this batch
    current_static_df = spark.read \
        .format("jdbc") \
        .option("url", "jdbc:postgresql://localhost:5432/userdb") \
        .option("dbtable", "user_profiles") \
        .option("user", "admin") \
        .option("password", "password") \
        .load()
    
    # Perform join with fresh data
    enriched_batch = batch_df.join(
        current_static_df,
        batch_df.user_id == current_static_df.user_id,
        "inner"
    ).select(
        batch_df.transaction_id,
        batch_df.user_id,
        batch_df.amount,
        batch_df.timestamp,
        current_static_df.user_name,
        current_static_df.tier,
        current_static_df.risk_score
    )
    
    # Write to output
    enriched_batch.write \
        .format("parquet") \
        .mode("append") \
        .save("s3://bucket/enriched_transactions/")

# Apply foreachBatch
query = streaming_df.writeStream \
    .foreachBatch(process_batch) \
    .option("checkpointLocation", "s3://bucket/checkpoints/enriched/") \
    .start()

Performance Optimization Strategies

Broadcast joins are automatic for small static DataFrames, but you can force broadcasting to ensure optimal performance:

from pyspark.sql.functions import broadcast

enriched_stream = streaming_df.join(
    broadcast(static_df),
    streaming_df.user_id == static_df.user_id,
    "inner"
)

Monitor broadcast size limits. The default spark.sql.autoBroadcastJoinThreshold is 10MB. Adjust based on your cluster memory:

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50MB")

For larger static datasets, partition both DataFrames on the join key:

# Repartition streaming data
streaming_df_partitioned = streaming_df.repartition(col("user_id"))

# Repartition static data
static_df_partitioned = static_df.repartition(col("user_id"))

enriched_stream = streaming_df_partitioned.join(
    static_df_partitioned,
    "user_id",
    "inner"
)

Handling Schema Evolution

Schema mismatches between streaming and static data cause runtime failures. Implement schema validation and graceful degradation:

from pyspark.sql.functions import lit

def safe_join_with_schema_check(streaming_df, static_df, join_key, expected_static_schema):
    # Verify join key exists in both DataFrames
    if join_key not in streaming_df.columns:
        raise ValueError(f"Join key {join_key} missing in streaming DataFrame")
    if join_key not in static_df.columns:
        raise ValueError(f"Join key {join_key} missing in static DataFrame")
    
    # Add null defaults for expected static fields absent from the loaded data,
    # so downstream consumers see a stable schema even when the source drifts
    for field in expected_static_schema.fields:
        if field.name != join_key and field.name not in static_df.columns:
            static_df = static_df.withColumn(field.name, lit(None).cast(field.dataType))
    
    # Left outer join preserves streaming records without matching reference data
    return streaming_df.join(
        static_df,
        streaming_df[join_key] == static_df[join_key],
        "left_outer"
    )

enriched_stream = safe_join_with_schema_check(streaming_df, static_df, "user_id", static_schema)

Production Monitoring and Debugging

Track join performance and data quality metrics:

from pyspark.sql.functions import col

def process_with_metrics(batch_df, batch_id):
    static_df = load_static_data()  # Your loading logic
    
    # Track input counts
    input_count = batch_df.count()
    
    # Perform join
    joined = batch_df.join(static_df, "user_id", "left_outer")
    
    # Calculate metrics
    matched_count = joined.filter(col("user_name").isNotNull()).count()
    unmatched_count = input_count - matched_count
    
    # Log metrics
    print(f"Batch {batch_id}: Input={input_count}, Matched={matched_count}, Unmatched={unmatched_count}")
    
    # Write results
    joined.write.format("parquet").mode("append").save("output_path")

query = streaming_df.writeStream \
    .foreachBatch(process_with_metrics) \
    .option("checkpointLocation", "checkpoint_path") \
    .start()

Stream-static joins provide a powerful pattern for real-time data enrichment without the complexity of stateful stream-stream joins. Choose the refresh strategy based on your reference data update frequency and latency requirements.
