PySpark - Best Practices for Production Code
Key Insights
- Structure PySpark applications like any production Python project—explicit schemas, configuration management, and comprehensive testing are non-negotiable for maintainable code.
- Most performance issues stem from unnecessary shuffles and poor join strategies; understanding data distribution and using broadcast joins appropriately can improve job runtime by 10x or more.
- Treat your Spark jobs as distributed systems: implement proper error handling, structured logging, and monitoring from day one rather than retrofitting after production failures.
Project Structure and Configuration Management
Production PySpark code deserves the same engineering rigor as any backend service. The days of monolithic notebooks deployed to production should be behind us. Start with a clear project structure:
my_spark_project/
├── src/
│   └── my_spark_project/
│       ├── __init__.py
│       ├── config.py
│       ├── jobs/
│       │   ├── __init__.py
│       │   └── daily_aggregation.py
│       ├── transformations/
│       │   ├── __init__.py
│       │   └── user_metrics.py
│       └── schemas/
│           ├── __init__.py
│           └── events.py
├── tests/
├── pyproject.toml
└── spark-submit.sh
Configuration should be externalized and validated at startup. Using Pydantic ensures your job fails fast with clear error messages rather than halfway through processing:
from typing import Optional

# Pydantic v1 style; in v2, BaseSettings moved to the pydantic-settings
# package and @validator became @field_validator.
from pydantic import BaseSettings, validator


class SparkConfig(BaseSettings):
    app_name: str
    master: str = "yarn"
    executor_memory: str = "4g"
    executor_cores: int = 2
    num_executors: int = 10
    shuffle_partitions: int = 200

    # Data paths
    input_path: str
    output_path: str
    checkpoint_dir: Optional[str] = None

    @validator('executor_memory')
    def validate_memory(cls, v):
        if not v.endswith(('g', 'm')):
            raise ValueError('Memory must end with g or m')
        return v

    class Config:
        env_prefix = "SPARK_"
        env_file = ".env"


def create_spark_session(config: SparkConfig):
    from pyspark.sql import SparkSession

    return (SparkSession.builder
            .appName(config.app_name)
            .master(config.master)
            .config("spark.executor.memory", config.executor_memory)
            .config("spark.executor.cores", config.executor_cores)
            .config("spark.sql.shuffle.partitions", config.shuffle_partitions)
            .config("spark.sql.adaptive.enabled", "true")
            .getOrCreate())
Efficient DataFrame Operations
The single biggest performance killer in PySpark is unnecessary data shuffling. Every groupBy, join, and distinct operation potentially moves data across the cluster. Understanding this changes how you write code.
Here’s a common anti-pattern—joining a large fact table with a small dimension table:
# Anti-pattern: Regular join with small table
def get_user_orders_bad(orders_df, users_df):
    return orders_df.join(users_df, "user_id")  # Shuffles both datasets
The optimized version uses broadcast joins when one side is small (typically under 10MB, configurable up to 8GB):
from pyspark.sql.functions import broadcast

def get_user_orders_good(orders_df, users_df):
    # Broadcast the smaller DataFrame to all executors
    return orders_df.join(broadcast(users_df), "user_id")
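If you prefer automatic broadcasting over explicit hints, the threshold is tunable. A hypothetical session-level tweak (the 50MB value is an illustrative assumption, not a recommendation), plus a plan check to confirm Spark actually chose a broadcast join:

```python
# Sketch: assumes an existing SparkSession `spark` and the DataFrames above.
# Raise the auto-broadcast threshold from its 10MB default to 50MB (assumed value).
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(50 * 1024 * 1024))

# Inspect the physical plan; a successful broadcast shows BroadcastHashJoin
# where a shuffle-based plan would show SortMergeJoin.
orders_df.join(broadcast(users_df), "user_id").explain()
```

Checking `explain()` output after a tuning change is cheap insurance: the optimizer silently falls back to a sort-merge join when size estimates exceed the threshold.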
Column pruning matters more than you think. Spark’s optimizer helps, but being explicit improves readability and ensures optimization:
# Anti-pattern: Select all, filter later
def process_events_bad(events_df):
    return (events_df
            .filter(events_df.event_type == "purchase")
            .groupBy("user_id")
            .count())

# Better: Select only needed columns early
def process_events_good(events_df):
    return (events_df
            .select("user_id", "event_type")  # Prune columns first
            .filter(events_df.event_type == "purchase")
            .groupBy("user_id")
            .count())
Caching is often overused. Only cache when you’re reusing a DataFrame multiple times AND the computation is expensive. Unnecessary caching wastes memory and can cause spills to disk:
from pyspark.sql.functions import col, count, sum as spark_sum

def compute_with_reuse(events_df):
    # Good use of cache: expensive computation reused multiple times
    user_metrics = (events_df
                    .groupBy("user_id")
                    .agg(
                        count("*").alias("event_count"),
                        spark_sum("revenue").alias("total_revenue")
                    )
                    .cache())  # Will be used for both outputs

    high_value = user_metrics.filter(col("total_revenue") > 1000)
    low_engagement = user_metrics.filter(col("event_count") < 5)

    # Unpersist only after downstream actions have materialized both outputs;
    # DataFrames are lazy, so unpersisting here would evict the cache before
    # it is ever used.
    return high_value, low_engagement
Schema Management and Data Validation
Never rely on schema inference in production. It’s slow (requires a full data scan), inconsistent (different files may infer differently), and fragile (empty partitions cause failures). Define schemas explicitly:
from pyspark.sql.types import (
    StructType, StructField, StringType,
    LongType, TimestampType, DoubleType
)

EVENT_SCHEMA = StructType([
    StructField("event_id", StringType(), nullable=False),
    StructField("user_id", StringType(), nullable=False),
    StructField("event_type", StringType(), nullable=False),
    StructField("timestamp", TimestampType(), nullable=False),
    StructField("properties", StringType(), nullable=True),  # JSON string
    StructField("revenue", DoubleType(), nullable=True),
])

def read_events(spark, path: str):
    return spark.read.schema(EVENT_SCHEMA).parquet(path)
Data validation should happen early in your pipeline. Fail fast on bad data rather than propagating garbage:
from pyspark.sql.functions import col, when, count

class DataValidationError(Exception):
    pass

def validate_events(df, max_null_ratio: float = 0.01):
    """Validate event data quality, raise on failures."""
    total_count = df.count()
    if total_count == 0:
        raise DataValidationError("Empty DataFrame received")

    # Check null ratios for required fields
    null_counts = df.select([
        count(when(col(c).isNull(), 1)).alias(c)
        for c in ["event_id", "user_id", "event_type"]
    ]).collect()[0]

    for field in ["event_id", "user_id", "event_type"]:
        null_ratio = null_counts[field] / total_count
        if null_ratio > max_null_ratio:
            raise DataValidationError(
                f"Field {field} has {null_ratio:.2%} nulls, "
                f"exceeds threshold {max_null_ratio:.2%}"
            )
    return df
Error Handling and Logging
Spark’s distributed nature makes error handling tricky. Exceptions on executors don’t automatically propagate cleanly. Configure logging properly with a log4j2.properties file:
# log4j2.properties
rootLogger.level = WARN
rootLogger.appenderRef.stdout.ref = console
appender.console.type = Console
appender.console.name = console
appender.console.layout.type = PatternLayout
appender.console.layout.pattern = %d{yyyy-MM-dd HH:mm:ss} %-5p [%t] %c{1}:%L - %m%n
# Your application's logging
logger.myapp.name = my_spark_project
logger.myapp.level = INFO
For Python logging that works with Spark:
import logging

def get_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        ))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger
# Usage in your job
logger = get_logger(__name__)

def run_job(spark, config):
    logger.info(f"Starting job with config: {config}")
    try:
        df = read_events(spark, config.input_path)
        # Note: count() triggers a full job; log it only when the visibility
        # is worth the extra pass over the data.
        logger.info(f"Read {df.count()} events")
        # ... processing
    except Exception as e:
        logger.error(f"Job failed: {e}", exc_info=True)
        raise
Testing PySpark Applications
Testing Spark code requires a local SparkSession. Use pytest fixtures to manage session lifecycle:
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark():
    """Create a SparkSession for testing."""
    spark = (SparkSession.builder
             .master("local[2]")
             .appName("pytest")
             .config("spark.sql.shuffle.partitions", "2")
             .config("spark.ui.enabled", "false")
             .getOrCreate())
    yield spark
    spark.stop()

@pytest.fixture
def sample_events(spark):
    """Create sample event data for testing."""
    data = [
        ("e1", "u1", "purchase", 100.0),
        ("e2", "u1", "view", None),
        ("e3", "u2", "purchase", 50.0),
    ]
    return spark.createDataFrame(
        data, ["event_id", "user_id", "event_type", "revenue"]
    )
Test transformations in isolation:
from my_spark_project.transformations.user_metrics import calculate_user_revenue

def test_calculate_user_revenue(spark, sample_events):
    result = calculate_user_revenue(sample_events)
    result_dict = {row.user_id: row.total_revenue for row in result.collect()}
    assert result_dict["u1"] == 100.0
    assert result_dict["u2"] == 50.0
Resource Optimization and Monitoring
Tuning Spark resources is part science, part experimentation. Start with these guidelines:
- Executor memory: 4-8GB is usually optimal. Larger heaps cause GC issues.
- Executor cores: 2-5 cores per executor. More cores mean more concurrent tasks but also more memory pressure.
- Parallelism: Set spark.sql.shuffle.partitions to 2-3x your total executor cores.
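The parallelism rule of thumb is simple arithmetic; a small helper makes it explicit (the function name and default factor are ours, following the 2-3x guideline above):

```python
def recommended_shuffle_partitions(num_executors: int,
                                   cores_per_executor: int,
                                   factor: int = 3) -> int:
    """Rule of thumb: 2-3x the cluster's total executor cores."""
    total_cores = num_executors * cores_per_executor
    return total_cores * factor

# e.g. 20 executors x 4 cores at 3x -> 240 shuffle partitions
```

Feed the result into spark.sql.shuffle.partitions at session creation; with adaptive query execution enabled, Spark will coalesce small partitions down from this upper bound at runtime.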
Use accumulators to track job metrics:
def run_with_monitoring(spark, df):
    records_processed = spark.sparkContext.accumulator(0)
    errors_encountered = spark.sparkContext.accumulator(0)

    def process_partition(iterator):
        count = 0
        for row in iterator:
            try:
                # Process row
                count += 1
                yield row
            except Exception:
                errors_encountered.add(1)
        records_processed.add(count)

    result = df.rdd.mapPartitions(process_partition).toDF(df.schema)
    result.write.parquet("/output")  # Trigger execution

    # Accumulator updates inside transformations can be applied more than once
    # if a task is retried, so treat these as approximate counters.
    print(f"Processed: {records_processed.value}, Errors: {errors_encountered.value}")
Deployment Patterns
Package your application properly for cluster submission. A typical spark-submit script:
#!/bin/bash
set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
APP_VERSION="${APP_VERSION:-1.0.0}"

spark-submit \
  --master yarn \
  --deploy-mode cluster \
  --num-executors 20 \
  --executor-memory 8g \
  --executor-cores 4 \
  --driver-memory 4g \
  --conf spark.sql.adaptive.enabled=true \
  --conf spark.sql.adaptive.coalescePartitions.enabled=true \
  --py-files "${SCRIPT_DIR}/dist/my_spark_project-${APP_VERSION}.zip" \
  "${SCRIPT_DIR}/src/my_spark_project/jobs/daily_aggregation.py" \
  --config-path "s3://bucket/configs/prod.yaml"
For containerized deployments, keep your Dockerfile simple:
FROM apache/spark-py:3.5.0
USER root
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY src/ /app/src/
WORKDIR /app
USER spark
ENTRYPOINT ["/opt/spark/bin/spark-submit"]
Production PySpark isn’t about clever tricks—it’s about applying solid software engineering practices to distributed computing. Explicit schemas, proper testing, structured logging, and sensible resource configuration will prevent most production issues before they happen.