PySpark - Streaming from Socket Source

Key Insights

• PySpark’s socket streaming provides a lightweight way to process real-time data streams over TCP connections, ideal for development, testing, and scenarios where you need to integrate with legacy systems that communicate via sockets.

• Socket sources are non-replayable and not fault-tolerant by design—they don’t track offsets or support checkpoint-based replay, making them suitable for testing but requiring additional infrastructure for production workloads.

• Structured Streaming with socket sources uses the same DataFrame/Dataset API as batch processing, allowing you to apply transformations, aggregations, and windowing operations with minimal code changes.

Setting Up the Socket Server

Before consuming data with PySpark, you need a socket server that will emit data. For testing purposes, use Python’s built-in socket library to create a simple TCP server that streams data line by line.

import socket
import time
import random

def run_socket_server(host='localhost', port=9999):
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server_socket.bind((host, port))
    server_socket.listen(1)
    
    print(f"Socket server listening on {host}:{port}")
    
    conn, addr = server_socket.accept()
    print(f"Connection from {addr}")
    
    try:
        while True:
            # Simulate streaming sensor data
            temperature = random.uniform(20.0, 30.0)
            humidity = random.uniform(40.0, 60.0)
            timestamp = int(time.time())
            
            message = f"{timestamp},{temperature:.2f},{humidity:.2f}\n"
            conn.sendall(message.encode('utf-8'))
            time.sleep(1)
    except (BrokenPipeError, ConnectionResetError):
        print("Client disconnected")
    except KeyboardInterrupt:
        print("\nShutting down server")
    finally:
        conn.close()
        server_socket.close()

if __name__ == "__main__":
    run_socket_server()

This server generates synthetic sensor readings every second. Run this in a separate terminal before starting your PySpark streaming job.
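Before wiring up Spark, it can help to confirm the server emits well-formed lines. Here is a minimal self-contained sanity check; mini_sensor_server and read_lines are hypothetical helpers (a condensed copy of the server above running in a background thread), and port 9998 is used so it won't clash with a real server on 9999:

```python
import socket
import threading
import time
import random

def mini_sensor_server(host="localhost", port=9998, n_messages=3):
    # Condensed version of the server above: emit a few CSV lines, then close
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    srv.bind((host, port))
    srv.listen(1)
    conn, _ = srv.accept()
    try:
        for _ in range(n_messages):
            line = f"{int(time.time())},{random.uniform(20, 30):.2f},{random.uniform(40, 60):.2f}\n"
            conn.sendall(line.encode("utf-8"))
    finally:
        conn.close()
        srv.close()

def read_lines(host="localhost", port=9998, n=3):
    # Connect as a client (retrying briefly while the server binds) and collect n lines
    for _ in range(50):
        try:
            sock = socket.create_connection((host, port))
            break
        except ConnectionRefusedError:
            time.sleep(0.1)
    buf = b""
    with sock:
        while buf.count(b"\n") < n:
            chunk = sock.recv(4096)
            if not chunk:
                break
            buf += chunk
    return buf.decode("utf-8").splitlines()[:n]

server = threading.Thread(target=mini_sensor_server, daemon=True)
server.start()
lines = read_lines()
for line in lines:
    print(line)
```

Each printed line should be a three-field CSV record: epoch seconds, temperature, humidity.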

Basic Socket Streaming Consumer

Create a PySpark Structured Streaming application that reads from the socket source. The socket format expects text data, with each line treated as a separate record.

from pyspark.sql import SparkSession
from pyspark.sql.functions import split, col

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("SocketStreamingBasic") \
    .master("local[*]") \
    .getOrCreate()

# Set log level to reduce noise
spark.sparkContext.setLogLevel("WARN")

# Read from socket source
lines_df = spark.readStream \
    .format("socket") \
    .option("host", "localhost") \
    .option("port", 9999) \
    .load()

# The socket source provides a single 'value' column with string data
# Parse the CSV format
parsed_df = lines_df.select(
    split(col("value"), ",").getItem(0).cast("long").alias("timestamp"),
    split(col("value"), ",").getItem(1).cast("double").alias("temperature"),
    split(col("value"), ",").getItem(2).cast("double").alias("humidity")
)

# Write to console
query = parsed_df.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", False) \
    .start()

query.awaitTermination()

The socket format only provides a value column containing the raw text. You must parse this data yourself using string manipulation functions like split.
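For intuition, the same parse can be mirrored in plain Python. Like Spark's split/getItem/cast combination, a missing or malformed field becomes None (null) rather than raising an error; parse_record is a hypothetical helper, not part of the PySpark API:

```python
def parse_record(value: str):
    # Mirror of the split/getItem/cast logic: missing or bad fields
    # become None, just as Spark's cast yields null instead of raising
    parts = value.split(",")

    def item(i, cast):
        if i >= len(parts):
            return None  # getItem past the end of the array -> null
        try:
            return cast(parts[i])
        except ValueError:
            return None  # failed cast -> null

    return {
        "timestamp": item(0, int),
        "temperature": item(1, float),
        "humidity": item(2, float),
    }

print(parse_record("1718000000,25.43,52.10"))
print(parse_record("1718000000,25.43"))  # short record: humidity is None
```

This null-on-failure behavior is worth knowing: a garbled line on the socket silently produces null columns downstream instead of crashing the query.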

Applying Transformations and Aggregations

Structured Streaming supports stateful operations like windowing and aggregations. Here’s how to calculate rolling statistics over tumbling windows:

from pyspark.sql import SparkSession
from pyspark.sql.functions import split, col, window, avg, max, min, count
from pyspark.sql.types import TimestampType

spark = SparkSession.builder \
    .appName("SocketStreamingAggregations") \
    .master("local[*]") \
    .getOrCreate()

spark.sparkContext.setLogLevel("WARN")

lines_df = spark.readStream \
    .format("socket") \
    .option("host", "localhost") \
    .option("port", 9999) \
    .load()

# Parse and convert the epoch seconds to a proper timestamp type
parsed_df = lines_df.select(
    split(col("value"), ",").getItem(0).cast("long").cast(TimestampType()).alias("event_time"),
    split(col("value"), ",").getItem(1).cast("double").alias("temperature"),
    split(col("value"), ",").getItem(2).cast("double").alias("humidity")
)

# Perform windowed aggregations
windowed_stats = parsed_df \
    .withWatermark("event_time", "10 seconds") \
    .groupBy(window(col("event_time"), "30 seconds")) \
    .agg(
        avg("temperature").alias("avg_temp"),
        max("temperature").alias("max_temp"),
        min("temperature").alias("min_temp"),
        avg("humidity").alias("avg_humidity"),
        count("*").alias("record_count")
    )

query = windowed_stats.writeStream \
    .outputMode("update") \
    .format("console") \
    .option("truncate", False) \
    .start()

query.awaitTermination()

The withWatermark function handles late-arriving data by defining how long to wait for delayed events. The update output mode emits only changed aggregation results.
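The window bucketing itself is simple arithmetic. A plain-Python sketch of how events map to 30-second tumbling windows, plus an approximate form of the watermark rule; tumbling_window and is_late are illustrative helpers, not Spark APIs:

```python
def tumbling_window(epoch_seconds: int, duration: int = 30):
    # Assign an event to its tumbling window: the same bucketing
    # window() performs for a tumbling window with no offset
    start = epoch_seconds - (epoch_seconds % duration)
    return (start, start + duration)

def is_late(event_time: int, max_seen: int, watermark_delay: int = 10):
    # Approximation of the watermark rule: an event is dropped once
    # it falls behind the max observed event time minus the delay
    return event_time < max_seen - watermark_delay

# Events 13 seconds apart can share a window...
print(tumbling_window(1718000012), tumbling_window(1718000025))
# ...while a later event starts the next window
print(tumbling_window(1718000040))
```

With a 10-second watermark, an event more than 10 seconds behind the newest event time seen so far is discarded, which is what bounds the state Spark must keep per window.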

Multiple Output Sinks and Query Management

Production applications often need to write processed data to multiple destinations. PySpark supports running multiple streaming queries simultaneously:

from pyspark.sql import SparkSession
from pyspark.sql.functions import split, col, current_timestamp

spark = SparkSession.builder \
    .appName("SocketStreamingMultipleSinks") \
    .master("local[*]") \
    .config("spark.sql.streaming.checkpointLocation", "/tmp/checkpoint") \
    .getOrCreate()

spark.sparkContext.setLogLevel("WARN")

lines_df = spark.readStream \
    .format("socket") \
    .option("host", "localhost") \
    .option("port", 9999) \
    .load()

parsed_df = lines_df.select(
    split(col("value"), ",").getItem(0).cast("long").alias("timestamp"),
    split(col("value"), ",").getItem(1).cast("double").alias("temperature"),
    split(col("value"), ",").getItem(2).cast("double").alias("humidity"),
    current_timestamp().alias("processing_time")
)

# Filter for high temperature alerts
alerts_df = parsed_df.filter(col("temperature") > 28.0)

# Console output for all data
query_all = parsed_df.writeStream \
    .outputMode("append") \
    .format("console") \
    .queryName("all_data") \
    .start()

# Memory sink for alerts (queryable table)
query_alerts = alerts_df.writeStream \
    .outputMode("append") \
    .format("memory") \
    .queryName("high_temp_alerts") \
    .start()

# File sink for archival
query_archive = parsed_df.writeStream \
    .outputMode("append") \
    .format("parquet") \
    .option("path", "/tmp/sensor_archive") \
    .option("checkpointLocation", "/tmp/checkpoint/archive") \
    .start()

# Query the in-memory table
import time
time.sleep(10)  # Let some data accumulate
spark.sql("SELECT * FROM high_temp_alerts").show()

# Block until any one of the queries terminates
spark.streams.awaitAnyTermination()

The memory sink creates a queryable table you can access with SQL. This is useful for debugging or serving real-time dashboards.

Error Handling and Connection Management

Socket sources don’t automatically reconnect on failure. Implement retry logic and graceful shutdown:

from pyspark.sql import SparkSession
import time
import sys

def create_streaming_query(max_retries=3):
    retry_count = 0
    
    while retry_count < max_retries:
        try:
            spark = SparkSession.builder \
                .appName("SocketStreamingRobust") \
                .master("local[*]") \
                .getOrCreate()
            
            spark.sparkContext.setLogLevel("WARN")
            
            lines_df = spark.readStream \
                .format("socket") \
                .option("host", "localhost") \
                .option("port", 9999) \
                .option("includeTimestamp", "true") \
                .load()
            
            query = lines_df.writeStream \
                .outputMode("append") \
                .format("console") \
                .start()
            
            query.awaitTermination()
            break
            
        except Exception as e:
            retry_count += 1
            print(f"Error occurred: {e}")
            print(f"Retry attempt {retry_count}/{max_retries}")
            
            if retry_count < max_retries:
                time.sleep(5)
            else:
                print("Max retries reached. Exiting.")
                sys.exit(1)

if __name__ == "__main__":
    try:
        create_streaming_query()
    except KeyboardInterrupt:
        print("\nShutdown requested")
        sys.exit(0)

Performance Considerations

Socket sources have inherent limitations. They process data from a single TCP connection, creating a bottleneck. For production workloads:

  • Use socket sources only for testing or low-throughput scenarios
  • Consider Kafka, Kinesis, or cloud-native streaming services for production
  • Monitor the inputRowsPerSecond and processedRowsPerSecond metrics
  • Remember that rate-limiting options such as maxOffsetsPerTrigger apply only to Kafka sources; the socket source has no built-in rate control

The socket source lacks fault tolerance. If your PySpark application crashes, you lose data that was in flight. For critical applications, use sources that support checkpointing and replay mechanisms.
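Those throughput metrics surface in query.lastProgress, which PySpark exposes as a plain dict. A sketch of a monitoring helper; summarize_progress is hypothetical and the sample values are made up, but the keys match the shape of StreamingQueryProgress:

```python
def summarize_progress(progress: dict) -> str:
    # Format the throughput fields of a StreamingQueryProgress dict,
    # as returned by query.lastProgress in PySpark
    if not progress:
        return "no batches completed yet"
    return (
        f"batch {progress.get('batchId')}: "
        f"{progress.get('numInputRows', 0)} rows in, "
        f"{progress.get('inputRowsPerSecond', 0.0):.1f} rows/s input, "
        f"{progress.get('processedRowsPerSecond', 0.0):.1f} rows/s processed"
    )

# Made-up values in the shape lastProgress returns:
sample = {"batchId": 42, "numInputRows": 30,
          "inputRowsPerSecond": 1.0, "processedRowsPerSecond": 28.5}
print(summarize_progress(sample))

# In a live job, this might run in a loop (sketch):
#   while query.isActive:
#       print(summarize_progress(query.lastProgress))
#       time.sleep(5)
```

If processedRowsPerSecond consistently trails inputRowsPerSecond, the query is falling behind and batches will queue up.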

Socket streaming remains valuable for rapid prototyping, integration testing, and scenarios where you need to ingest data from legacy systems that only support TCP socket communication. Understanding its capabilities and limitations helps you choose the right tool for your streaming architecture.
