PySpark - Best Practices for Production Code

Key Insights

  • Structure PySpark applications like any production Python project—explicit schemas, configuration management, and comprehensive testing are non-negotiable for maintainable code.
  • Most performance issues stem from unnecessary shuffles and poor join strategies; understanding data distribution and using broadcast joins appropriately can cut job runtime by an order of magnitude when a large table is joined to a small one.
  • Treat your Spark jobs as distributed systems: implement proper error handling, structured logging, and monitoring from day one rather than retrofitting after production failures.

Project Structure and Configuration Management

Production PySpark code deserves the same engineering rigor as any backend service. The days of monolithic notebooks deployed to production should be behind us. Start with a clear project structure:

my_spark_project/
├── src/
│   └── my_spark_project/
│       ├── __init__.py
│       ├── config.py
│       ├── jobs/
│       │   ├── __init__.py
│       │   └── daily_aggregation.py
│       ├── transformations/
│       │   ├── __init__.py
│       │   └── user_metrics.py
│       └── schemas/
│           ├── __init__.py
│           └── events.py
├── tests/
├── pyproject.toml
└── spark-submit.sh

Configuration should be externalized and validated at startup. Using Pydantic ensures your job fails fast with clear error messages rather than halfway through processing:

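from typing import Optional

# Pydantic v2: BaseSettings lives in the separate pydantic-settings package,
# and field_validator replaces the deprecated validator decorator
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

class SparkConfig(BaseSettings):
    # Read SPARK_APP_NAME, SPARK_INPUT_PATH, ... from the environment or .env
    model_config = SettingsConfigDict(env_prefix="SPARK_", env_file=".env")

    app_name: str
    master: str = "yarn"
    executor_memory: str = "4g"
    executor_cores: int = 2
    num_executors: int = 10
    shuffle_partitions: int = 200

    # Data paths
    input_path: str
    output_path: str
    checkpoint_dir: Optional[str] = None

    @field_validator("executor_memory")
    @classmethod
    def validate_memory(cls, v: str) -> str:
        if not v.endswith(("g", "m")):
            raise ValueError("Memory must end with g or m")
        return v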

def create_spark_session(config: SparkConfig):
    from pyspark.sql import SparkSession

    # Values set here take precedence over spark-submit flags such as
    # --executor-memory, so set only what the job itself should own
    return (SparkSession.builder
        .appName(config.app_name)
        .master(config.master)
        .config("spark.executor.memory", config.executor_memory)
        .config("spark.executor.cores", config.executor_cores)
        .config("spark.executor.instances", config.num_executors)
        .config("spark.sql.shuffle.partitions", config.shuffle_partitions)
        .config("spark.sql.adaptive.enabled", "true")
        .getOrCreate())
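The endswith check in validate_memory still accepts strings such as "g" or "1.5g", which Spark rejects later. A slightly stricter parse is cheap. The sketch below uses a hypothetical helper, memory_to_mb, that accepts only a positive integer followed by m or g and converts the value to MiB; it deliberately ignores the other unit suffixes Spark understands (k, t, mb, gb, and so on).

```python
import re

# Hypothetical helper: a positive integer followed by m or g (case-insensitive)
_MEMORY_PATTERN = re.compile(r"(\d+)([mg])", re.IGNORECASE)


def memory_to_mb(value: str) -> int:
    """Convert a setting like '4g' or '512m' to MiB, rejecting anything else."""
    match = _MEMORY_PATTERN.fullmatch(value.strip())
    if match is None:
        raise ValueError(f"Invalid memory setting {value!r}; expected e.g. '4g' or '512m'")
    amount, unit = int(match.group(1)), match.group(2).lower()
    if amount == 0:
        raise ValueError("Memory setting must be greater than zero")
    # 1 GiB = 1024 MiB
    return amount * 1024 if unit == "g" else amount
```

Calling memory_to_mb(v) inside validate_memory (and still returning v) would surface malformed values at startup together with any other validation errors.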

Efficient DataFrame Operations

The single biggest performance killer in PySpark is unnecessary data shuffling. Every groupBy, join, and distinct operation potentially moves data across the cluster. Understanding this changes how you write code.

Here’s a common anti-pattern—joining a large fact table with a small dimension table:

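# Anti-pattern: Regular join with small table
def get_user_orders_bad(orders_df, users_df):
    # Sort-merge join: shuffles both sides unless Spark's size estimate
    # for users_df happens to fall under the auto-broadcast threshold
    return orders_df.join(users_df, "user_id")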

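The optimized version broadcasts the small side to every executor. Spark already does this automatically when its size estimate falls below spark.sql.autoBroadcastJoinThreshold (10MB by default); the broadcast() hint forces it regardless of that threshold. Spark refuses to broadcast tables larger than 8GB, and in practice driver and executor memory set a much lower ceiling: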

from pyspark.sql.functions import broadcast

def get_user_orders_good(orders_df, users_df):
    # Broadcast the smaller DataFrame to all executors
    return orders_df.join(broadcast(users_df), "user_id")

Column pruning matters more than you think. For columnar sources like Parquet, Spark’s optimizer already prunes unused columns on its own, but selecting explicitly documents which columns each step depends on and keeps rows narrow if the pipeline later gains steps the optimizer cannot see through, such as a conversion to an RDD:

from pyspark.sql.functions import col

# Implicit: relies on the optimizer to prune unused columns
def process_events_bad(events_df):
    return (events_df
        .filter(col("event_type") == "purchase")
        .groupBy("user_id")
        .count())

# Explicit: select only the needed columns early
def process_events_good(events_df):
    return (events_df
        .select("user_id", "event_type")  # Prune columns first
        .filter(col("event_type") == "purchase")
        .groupBy("user_id")
        .count())

Caching is often overused. Only cache when you’re reusing a DataFrame multiple times AND the computation is expensive. Unnecessary caching wastes memory and can cause spills to disk:

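from pyspark.sql.functions import col, count, sum as spark_sum

def compute_with_reuse(events_df, high_value_path: str, low_engagement_path: str):
    # Good use of cache: an expensive aggregation read by two actions
    user_metrics = (events_df
        .groupBy("user_id")
        .agg(
            count("*").alias("event_count"),
            spark_sum("revenue").alias("total_revenue")
        )
        .cache())

    try:
        # Each write is an action; the second one reads from the cache
        user_metrics.filter(col("total_revenue") > 1000).write.mode("overwrite").parquet(high_value_path)
        user_metrics.filter(col("event_count") < 5).write.mode("overwrite").parquet(low_engagement_path)
    finally:
        # Unpersist only after the actions have run: DataFrames are lazy, so
        # unpersisting before them would drop the cache before it is ever used
        user_metrics.unpersist()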

Schema Management and Data Validation

Never rely on schema inference in production. It’s slow (JSON and CSV inference needs an extra pass over the data), inconsistent (different files or batches may infer different types), and fragile (an empty or malformed batch can produce an unexpected schema or fail outright). Define schemas explicitly:

from pyspark.sql.types import (
    StructType, StructField, StringType, 
    LongType, TimestampType, DoubleType
)

EVENT_SCHEMA = StructType([
    StructField("event_id", StringType(), nullable=False),
    StructField("user_id", StringType(), nullable=False),
    StructField("event_type", StringType(), nullable=False),
    StructField("timestamp", TimestampType(), nullable=False),
    StructField("properties", StringType(), nullable=True),  # JSON string
    StructField("revenue", DoubleType(), nullable=True),
])

def read_events(spark, path: str):
    return spark.read.schema(EVENT_SCHEMA).parquet(path)

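Data validation should happen early in your pipeline: fail fast on bad data rather than propagating garbage. Note that nullable=False in a read schema is not a check. Spark reads file-based sources such as Parquet with every column nullable, so null constraints have to be verified explicitly: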

from pyspark.sql.functions import col, when, count

class DataValidationError(Exception):
    pass

def validate_events(df, max_null_ratio: float = 0.01):
    """Validate event data quality, raise on failures."""
    total_count = df.count()
    
    if total_count == 0:
        raise DataValidationError("Empty DataFrame received")
    
    # Check null ratios for required fields
    null_counts = df.select([
        count(when(col(c).isNull(), 1)).alias(c)
        for c in ["event_id", "user_id", "event_type"]
    ]).collect()[0]
    
    for field in ["event_id", "user_id", "event_type"]:
        null_ratio = null_counts[field] / total_count
        if null_ratio > max_null_ratio:
            raise DataValidationError(
                f"Field {field} has {null_ratio:.2%} nulls, "
                f"exceeds threshold {max_null_ratio:.2%}"
            )
    
    return df

Error Handling and Logging

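Spark’s distributed nature makes error handling tricky. An exception raised on an executor fails only that task; on a cluster Spark retries it (up to spark.task.maxFailures, 4 by default) before failing the job, and the driver then sees the error wrapped in a JVM-side exception that embeds the Python traceback. Configure JVM-side logging with a log4j2.properties file (Spark 3.3 and later use Log4j 2):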

# log4j2.properties
rootLogger.level = WARN
rootLogger.appenderRef.stdout.ref = console

appender.console.type = Console
appender.console.name = console
appender.console.layout.type = PatternLayout
appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss} %-5p [%t] %c{1}:%L - %m%n

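# JVM-side loggers named my_spark_project (e.g. Scala or Java helpers);
# Python's logging module is not controlled by this file
logger.myapp.name = my_spark_project
logger.myapp.level = INFO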

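Python’s logging module is configured separately. On the driver its output goes to the driver’s stderr; anything logged inside functions that run on executors (UDFs, mapPartitions) ends up in the executor logs instead. A simple driver-side setup: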

import logging

def get_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        ))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger

# Usage in your job
logger = get_logger(__name__)

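def run_job(spark, config):
    # Log selected fields rather than the whole config, which may hold credentials
    logger.info("Starting job %s reading from %s", config.app_name, config.input_path)
    try:
        df = read_events(spark, config.input_path)
        # count() runs a full Spark job; keep it only if the extra pass is acceptable
        logger.info("Read %d events", df.count())
        # ... processing
    except Exception:
        logger.exception("Job failed")  # Logs at ERROR level with the traceback
        raise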
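The Key Insights call for structured logging, while the formatter above emits plain text. A JSON formatter is one way to get machine-parseable lines from Python’s standard logging module. JsonFormatter, get_json_logger, and the "context" key passed through extra= are hypothetical conventions, not part of Spark.

```python
import json
import logging


class JsonFormatter(logging.Formatter):
    """Render each log record as one JSON object per line."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "time": self.formatTime(record, "%Y-%m-%dT%H:%M:%S"),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Hypothetical convention: structured fields arrive via extra={"context": {...}}
        context = getattr(record, "context", None)
        if isinstance(context, dict):
            payload.update(context)
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        # default=str keeps non-JSON values (paths, datetimes) from raising
        return json.dumps(payload, default=str)


def get_json_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(JsonFormatter())
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger
```

Usage looks like logger.info("Read %d events", n, extra={"context": {"job": "daily_aggregation"}}), which yields a line a log aggregator can filter by job.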

Testing PySpark Applications

Testing Spark code requires a local SparkSession. Use pytest fixtures, typically in tests/conftest.py, to manage the session lifecycle:

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    """Create a SparkSession for testing."""
    spark = (SparkSession.builder
        .master("local[2]")
        .appName("pytest")
        .config("spark.sql.shuffle.partitions", "2")
        .config("spark.ui.enabled", "false")
        .getOrCreate())
    yield spark
    spark.stop()

@pytest.fixture
def sample_events(spark):
    """Create sample event data for testing."""
    data = [
        ("e1", "u1", "purchase", 100.0),
        ("e2", "u1", "view", None),
        ("e3", "u2", "purchase", 50.0),
    ]
    # Explicit DDL schema, so the None in revenue never leaves the type to inference
    return spark.createDataFrame(
        data, "event_id string, user_id string, event_type string, revenue double"
    )

Test transformations in isolation:

from my_spark_project.transformations.user_metrics import calculate_user_revenue

def test_calculate_user_revenue(spark, sample_events):
    result = calculate_user_revenue(sample_events)
    
    result_dict = {row.user_id: row.total_revenue for row in result.collect()}
    
    assert result_dict["u1"] == 100.0
    assert result_dict["u2"] == 50.0

Resource Optimization and Monitoring

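Tuning Spark resources is part science, part experimentation. Treat these guidelines as starting points, not rules:

  • Executor memory: 4-8GB is a reasonable default. Very large heaps (tens of GB) tend to suffer long garbage-collection pauses. In PySpark, the Python worker processes live outside the JVM heap, so they need room in spark.executor.memoryOverhead (or an explicit spark.executor.pyspark.memory).
  • Executor cores: 2-5 cores per executor. More cores mean more concurrent tasks, but those tasks share the executor’s memory.
  • Parallelism: Set spark.sql.shuffle.partitions to 2-3x your total executor cores; with adaptive query execution enabled, Spark can also coalesce small shuffle partitions at runtime.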
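These heuristics reduce to simple arithmetic. The sketch below assumes the default overhead rule of max(384 MiB, 10% of executor memory); on Kubernetes, non-JVM jobs may default to a larger overhead factor, and spark.executor.pyspark.memory, if set, is requested on top of both. The function names are hypothetical.

```python
def shuffle_partitions(num_executors: int, cores_per_executor: int, factor: int = 3) -> int:
    """Rule of thumb from the list above: 2-3x the total executor cores."""
    return num_executors * cores_per_executor * factor


def container_memory_mb(
    executor_memory_mb: int, overhead_factor: float = 0.10, min_overhead_mb: int = 384
) -> int:
    """Executor heap plus the assumed default off-heap overhead, in MiB."""
    overhead = max(min_overhead_mb, int(executor_memory_mb * overhead_factor))
    return executor_memory_mb + overhead
```

For 20 executors with 4 cores and 8g each, this gives 240 shuffle partitions at 3x, and each container requests 8192 + 819 = 9011 MiB before any PySpark-specific memory.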

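Use accumulators to track job metrics, with one caveat: Spark guarantees exactly-once updates only for accumulators updated inside actions. Updates made inside transformations, as in the example below, may be applied more than once when a task is retried or speculatively re-executed, so treat those counts as approximate: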

import logging

logger = logging.getLogger(__name__)

def run_with_monitoring(spark, df, output_path: str):
    records_processed = spark.sparkContext.accumulator(0)
    errors_encountered = spark.sparkContext.accumulator(0)

    def process_partition(iterator):
        processed = 0
        for row in iterator:
            try:
                result_row = row  # Placeholder: replace with real per-row processing
            except Exception:
                errors_encountered.add(1)
                continue  # Drop rows that fail processing
            processed += 1
            yield result_row
        records_processed.add(processed)

    result = df.rdd.mapPartitions(process_partition).toDF(df.schema)
    result.write.mode("overwrite").parquet(output_path)  # Action: triggers execution

    # Read accumulator values on the driver only after the action completes
    logger.info("Processed: %d, Errors: %d",
                records_processed.value, errors_encountered.value)

Deployment Patterns

Package your application for cluster submission: ship the package as a zip with --py-files and point spark-submit at the job’s entry script. A typical spark-submit script:

#!/bin/bash
set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
APP_VERSION="${APP_VERSION:-1.0.0}"

spark-submit \
    --master yarn \
    --deploy-mode cluster \
    --num-executors 20 \
    --executor-memory 8g \
    --executor-cores 4 \
    --driver-memory 4g \
    --conf spark.sql.adaptive.enabled=true \
    --conf spark.sql.adaptive.coalescePartitions.enabled=true \
    --py-files "${SCRIPT_DIR}/dist/my_spark_project-${APP_VERSION}.zip" \
    "${SCRIPT_DIR}/src/my_spark_project/jobs/daily_aggregation.py" \
    --config-path "s3://bucket/configs/prod.yaml"
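The script expects dist/my_spark_project-${APP_VERSION}.zip to exist. For --py-files, the package directory has to sit at the root of the archive so that import my_spark_project resolves on the cluster. A minimal sketch using the standard library’s shutil.make_archive; build_py_files_zip is a hypothetical helper, and a step in your build tool could do the same.

```python
import shutil
from pathlib import Path


def build_py_files_zip(project_root: str, version: str) -> str:
    """Zip src/my_spark_project so the package sits at the archive root."""
    root = Path(project_root)
    dist = root / "dist"
    dist.mkdir(parents=True, exist_ok=True)
    # root_dir/base_dir: archive paths start at my_spark_project/, not src/
    return shutil.make_archive(
        base_name=str(dist / f"my_spark_project-{version}"),
        format="zip",
        root_dir=str(root / "src"),
        base_dir="my_spark_project",
    )
```

Running it with the same version string as APP_VERSION produces the file the script passes to --py-files.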

For containerized deployments, keep your Dockerfile simple:

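FROM apache/spark:3.5.0-python3

USER root
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY src/ /app/src/
# Make the package importable by the driver's Python process
ENV PYTHONPATH=/app/src
WORKDIR /app

USER spark
# For Spark on Kubernetes, drop this line and keep the base image's
# /opt/entrypoint.sh, which starts the driver and executor pods
ENTRYPOINT ["/opt/spark/bin/spark-submit"]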

Production PySpark isn’t about clever tricks—it’s about applying solid software engineering practices to distributed computing. Explicit schemas, proper testing, structured logging, and sensible resource configuration will prevent most production issues before they happen.
