ETL Pipeline with PySpark - Complete Tutorial

Key Insights

  • PySpark ETL pipelines shine when processing data at scale: start simple with DataFrames, then optimize the bottlenecks you actually measure instead of optimizing prematurely
  • Defining schemas upfront prevents silent data corruption and makes debugging far easier than relying on schema inference in production
  • Partition your output data by columns you’ll frequently filter on, but avoid over-partitioning, which creates thousands of small files that kill read performance

Introduction to ETL and PySpark

ETL—Extract, Transform, Load—forms the backbone of modern data engineering. You pull data from source systems, clean and reshape it, then push it somewhere useful. Simple concept, complex execution.

PySpark makes this manageable at scale. When your CSV files grow from megabytes to terabytes, pandas, which processes data in a single machine’s memory, falls over. PySpark distributes work across a cluster, processing partitions in parallel while you write familiar Python code. You get the Python ecosystem you know plus Apache Spark’s distributed computing power.

This tutorial assumes Python 3.8+, a Java runtime (Spark 3.5 runs on Java 8, 11, or 17), and the ability to install packages. We’ll build a complete ETL pipeline from scratch, covering real-world patterns I’ve used in production systems processing billions of records.

Setting Up Your PySpark Environment

Install PySpark via pip:

pip install pyspark==3.5.0

For local development, that’s it. PySpark bundles a Spark distribution. For production clusters, you’ll coordinate with your infrastructure team on Spark versions and cluster managers.

Every PySpark application starts with a SparkSession:

from pyspark.sql import SparkSession

def create_spark_session(app_name: str = "ETL_Pipeline") -> SparkSession:
    return (
        SparkSession.builder
        .appName(app_name)
        .config("spark.sql.adaptive.enabled", "true")
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .config("spark.sql.parquet.compression.codec", "snappy")
        .getOrCreate()
    )

spark = create_spark_session()

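Adaptive Query Execution (AQE) re-optimizes query plans at runtime using statistics collected at stage boundaries; it has been on by default since Spark 3.2, so setting it here mainly makes the intent explicit. Kryo serialization is faster and more compact than Java serialization, though it mostly benefits RDD-based code, since DataFrames use Spark’s own internal binary format. Snappy is already the default Parquet codec. Together these configs give you solid, explicit defaults.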

Understand the architecture: your driver program plans and orchestrates the work, executors do the actual processing, and data lives in partitions distributed across executors. Each partition is processed by one task per stage, so more partitions mean more parallelism, but also more scheduling overhead.

Extract: Reading Data from Multiple Sources

Reading data seems trivial until you hit malformed records, schema mismatches, or connection timeouts. Define schemas explicitly:

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, DoubleType

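# Define schema explicitly - don't rely on inference in production.
# Note: file sources like CSV read every column as nullable, so nullable=False
# documents intent here but isn't enforced on read; validate nulls explicitly.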
orders_schema = StructType([
    StructField("order_id", StringType(), nullable=False),
    StructField("customer_id", StringType(), nullable=False),
    StructField("product_id", StringType(), nullable=False),
    StructField("quantity", IntegerType(), nullable=False),
    StructField("unit_price", DoubleType(), nullable=False),
    StructField("order_date", TimestampType(), nullable=False),
    StructField("status", StringType(), nullable=True)
])

def extract_csv(spark: SparkSession, path: str, schema: StructType):
    return (
        spark.read
        .option("header", "true")
        .option("mode", "DROPMALFORMED")  # or PERMISSIVE with _corrupt_record
        .option("timestampFormat", "yyyy-MM-dd HH:mm:ss")
        .schema(schema)
        .csv(path)
    )

orders_df = extract_csv(spark, "s3://bucket/raw/orders/*.csv", orders_schema)

Schema inference scans your data twice and guesses types. In production, a single malformed file can change your inferred schema and break downstream processes. Define it once, enforce it always.
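To make the mode option concrete, here is a toy plain-Python model of how each documented parse mode treats a value that can’t be converted to its column type. The function parse_rows and its two-column SCHEMA are hypothetical, and this is a sketch of the semantics, not Spark’s parser; in Spark, the raw text is kept only if you add a _corrupt_record string column to your schema.

```python
import csv
import io

# Hypothetical two-column schema: column name -> converter that raises ValueError on bad input
SCHEMA = {"order_id": str, "quantity": int}


def parse_rows(text, mode="PERMISSIVE"):
    """Toy model of Spark's CSV parse modes for values that fail type conversion."""
    reader = csv.reader(io.StringIO(text))
    next(reader)  # skip the header row
    rows = []
    for raw in reader:
        row, malformed = {}, False
        for (name, convert), value in zip(SCHEMA.items(), raw):
            try:
                row[name] = convert(value)
            except ValueError:
                row[name] = None  # the malformed field becomes null
                malformed = True
        if malformed and mode == "FAILFAST":
            raise ValueError(f"Malformed record: {raw}")
        if malformed and mode == "DROPMALFORMED":
            continue  # the whole row disappears, with no error raised
        # PERMISSIVE keeps the row; the raw text goes into the corrupt-record column
        row["_corrupt_record"] = ",".join(raw) if malformed else None
        rows.append(row)
    return rows


sample = "order_id,quantity\nA1,3\nA2,three\n"
permissive = parse_rows(sample, "PERMISSIVE")
dropped = parse_rows(sample, "DROPMALFORMED")
```

DROPMALFORMED shrinks your data without any error, so compare input and output counts, or use PERMISSIVE and route rows with a non-null _corrupt_record to a quarantine location, before trusting it in production.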

For database sources, use JDBC:

def extract_from_postgres(
    spark: SparkSession,
    table: str,
    connection_props: dict,
    partition_column: str = "id",  # must be a numeric, date, or timestamp column
    lower_bound: int = 1,
    upper_bound: int = 1_000_000,
    num_partitions: int = 10,
):
    jdbc_url = f"jdbc:postgresql://{connection_props['host']}:{connection_props['port']}/{connection_props['database']}"
    
    return (
        spark.read
        .format("jdbc")
        .option("url", jdbc_url)
        .option("dbtable", table)
        .option("user", connection_props["user"])
        .option("password", connection_props["password"])
        .option("driver", "org.postgresql.Driver")  # PostgreSQL JDBC jar must be on the classpath
        .option("fetchsize", "10000")  # Tune based on row size
        # Parallel reads: Spark issues num_partitions queries, each over one slice of the range
        .option("partitionColumn", partition_column)
        .option("lowerBound", str(lower_bound))
        .option("upperBound", str(upper_bound))
        .option("numPartitions", str(num_partitions))
        .load()
    )

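The partition options enable parallel reads: Spark splits the range between lowerBound and upperBound into numPartitions slices and runs one query per slice. The bounds only set the slice width; they don’t filter rows, so values outside the range still land in the first or last partition. Without these options, a single task reads the entire table over one connection.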
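A simplified plain-Python sketch of how those options become one WHERE clause per partition may help. The function jdbc_partition_predicates is hypothetical, and Spark’s own stride arithmetic differs in small details, so treat the exact boundaries as illustrative; the open-ended first and last slices are the point.

```python
def jdbc_partition_predicates(column, lower_bound, upper_bound, num_partitions):
    """Simplified sketch of how JDBC partition options become per-partition WHERE clauses."""
    if num_partitions <= 1:
        return [None]  # one partition, no WHERE clause: a single query reads everything
    stride = (upper_bound - lower_bound) // num_partitions
    predicates = []
    current = lower_bound
    for i in range(num_partitions):
        lower = f"{column} >= {current}" if i > 0 else None
        current += stride
        upper = f"{column} < {current}" if i < num_partitions - 1 else None
        if lower is None:
            # First slice is open below and also collects NULLs
            predicates.append(f"{upper} OR {column} IS NULL")
        elif upper is None:
            # Last slice is open above: ids beyond upper_bound are still read here
            predicates.append(lower)
        else:
            predicates.append(f"{lower} AND {upper}")
    return predicates


predicates = jdbc_partition_predicates("id", 1, 1_000_000, 10)
# predicates[0]  == "id < 100000 OR id IS NULL"
# predicates[-1] == "id >= 899992"
```

Because the edge slices are unbounded, bounds that are far off the real data make one partition do most of the work; deriving them from SELECT MIN(id), MAX(id) before the read keeps the slices balanced.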

Transform: Data Processing and Cleaning

Transformations are where your business logic lives. Start with the common operations:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

def clean_orders(df):
    """Apply standard cleaning transformations."""
    return (
        df
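        # Remove duplicates (which row survives for a repeated order_id is arbitrary)
        .dropDuplicates(["order_id"])
        # Handle nulls
        .fillna({"status": "unknown"})
        # Drop invalid quantities; rows with a null quantity are dropped too
        .filter(F.col("quantity") > 0)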
        # Standardize strings
        .withColumn("status", F.lower(F.trim(F.col("status"))))
        # Add derived columns
        .withColumn("total_amount", F.col("quantity") * F.col("unit_price"))
        .withColumn("processed_at", F.current_timestamp())
    )

Joins are common but expensive. Understand your data sizes:

def enrich_orders_with_customers(orders_df, customers_df):
    """Join orders with customer data."""
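    # Broadcast hint: copy the small customers table to every executor so the large
    # orders side is never shuffled. Spark already auto-broadcasts tables below
    # spark.sql.autoBroadcastJoinThreshold (10MB by default); the hint forces a
    # broadcast regardless of size, so use it only when the table fits in memory.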
    customers_small = F.broadcast(customers_df.select(
        "customer_id", "customer_name", "region", "segment"
    ))
    
    return orders_df.join(
        customers_small,
        on="customer_id",
        how="left"
    )
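Why broadcasting avoids a shuffle is easiest to see in a toy hash join. This plain-Python sketch (hypothetical broadcast_left_join, not Spark code) builds a lookup table from the small side, which is what a broadcast ships to each executor, and probes it with each large-side row in place. It assumes customer_id is unique in the dimension table.

```python
def broadcast_left_join(orders, customers, key, customer_cols):
    """Toy broadcast hash join: build a hash table from the small side, probe it with the large side."""
    # Build phase: in Spark, this table is what gets shipped to every executor.
    # Assumes the key is unique on the small side (one row per customer_id).
    lookup = {row[key]: row for row in customers}
    joined = []
    # Probe phase: each large-side row is matched where it already lives, so nothing is shuffled
    for order in orders:
        match = lookup.get(order[key], {})
        # Left join: unmatched orders are kept, with customer columns set to None
        joined.append({**order, **{col: match.get(col) for col in customer_cols}})
    return joined


orders = [{"order_id": "o1", "customer_id": "c1"}, {"order_id": "o2", "customer_id": "c9"}]
customers = [{"customer_id": "c1", "region": "EU", "segment": "retail"}]
result = broadcast_left_join(orders, customers, "customer_id", ["region", "segment"])
```

Without a broadcast, both sides would instead be shuffled by customer_id so matching keys meet on the same executor, which is the expensive step this avoids.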

For custom logic that doesn’t fit built-in functions, use UDFs sparingly:

import re

from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Compile the pattern once instead of on every call
EMAIL_PATTERN = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')

@udf(returnType=StringType())
def validate_email(email):
    """Validate and normalize email addresses."""
    if email is None:
        return None
    email = email.lower().strip()
    return email if EMAIL_PATTERN.match(email) else None

# Apply UDF (assumes customers_df has an "email" column)
customers_df = customers_df.withColumn(
    "validated_email",
    validate_email(F.col("email"))
)

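Python UDFs ship every row from the JVM to a Python worker process and back, and the optimizer can’t see inside them, which makes them much slower than built-in functions. Prefer built-ins when possible; this particular check could be written with F.lower, F.trim, and Column.rlike. For heavy UDF usage, consider pandas UDFs, which use Apache Arrow to process data in vectorized batches.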
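When you do need a UDF, one practical pattern is to keep its logic in a plain Python function you can unit-test without a cluster, then wrap it. This is a sketch: normalize_email is a hypothetical name that mirrors the UDF’s logic, and the wrapping line is shown as a comment because it needs PySpark.

```python
import re

EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")


def normalize_email(email):
    """Pure-Python core of the UDF: lowercase, trim, and return None if invalid."""
    if email is None:
        return None
    email = email.lower().strip()
    return email if EMAIL_PATTERN.match(email) else None


# In the Spark job, wrap the tested function instead of decorating it:
# validate_email = udf(normalize_email, returnType=StringType())
```

Bugs then surface in a fast local test run rather than as a failed stage on the cluster.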

Window functions handle running totals, rankings, and time-based calculations:

def add_customer_metrics(orders_df):
    """Add customer-level aggregations using window functions."""
    customer_window = Window.partitionBy("customer_id")
    
    order_window = (
        Window
        .partitionBy("customer_id")
        .orderBy("order_date", "order_id")  # order_id breaks ties so row_number is deterministic
    )
    
    return (
        orders_df
        .withColumn("customer_total_orders", F.count("order_id").over(customer_window))
        .withColumn("customer_lifetime_value", F.sum("total_amount").over(customer_window))
        .withColumn("order_sequence", F.row_number().over(order_window))
        .withColumn("days_since_last_order", 
            F.datediff(
                F.col("order_date"),
                F.lag("order_date").over(order_window)
            )
        )
    )
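To check what each window column should contain, here is a plain-Python model of the same computation on a tiny dataset (hypothetical customer_metrics, using date values in place of timestamps). The unordered window gives every row of a customer the same totals, while the ordered window drives row_number and lag. If you applied F.sum over order_window instead, Spark’s default frame for an ordered window (unbounded preceding to current row) would turn it into a running total.

```python
from collections import defaultdict
from datetime import date


def customer_metrics(orders):
    """Plain-Python model of the columns add_customer_metrics computes with window functions."""
    partitions = defaultdict(list)
    for order in orders:
        partitions[order["customer_id"]].append(order)  # Window.partitionBy("customer_id")

    result = []
    for rows in partitions.values():
        rows = sorted(rows, key=lambda r: r["order_date"])  # .orderBy("order_date")
        # Unordered window: every row in the partition sees the same totals
        total_orders = len(rows)
        lifetime_value = sum(r["total_amount"] for r in rows)
        previous = None  # lag("order_date") is null on a customer's first order
        for sequence, row in enumerate(rows, start=1):
            result.append({
                **row,
                "customer_total_orders": total_orders,
                "customer_lifetime_value": lifetime_value,
                "order_sequence": sequence,  # row_number()
                "days_since_last_order": None if previous is None else (row["order_date"] - previous).days,
            })
            previous = row["order_date"]
    return result


orders = [
    {"customer_id": "c1", "order_date": date(2024, 1, 10), "total_amount": 20.0},
    {"customer_id": "c1", "order_date": date(2024, 1, 1), "total_amount": 10.0},
    {"customer_id": "c2", "order_date": date(2024, 1, 5), "total_amount": 5.0},
]
metrics = customer_metrics(orders)
```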

Load: Writing Data to Target Systems

Writing data efficiently requires understanding partitioning and file formats:

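from typing import Optional

def load_to_data_lake(df, output_path: str, partition_cols: Optional[list] = None):
    """Write DataFrame to Parquet with optional partitioning."""
    # Caution: with the default spark.sql.sources.partitionOverwriteMode=static,
    # mode("overwrite") replaces everything under output_path, not only the
    # partitions present in df; set it to "dynamic" to replace just those.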
    writer = (
        df.write
        .mode("overwrite")
        .option("compression", "snappy")
    )
    
    if partition_cols:
        writer = writer.partitionBy(*partition_cols)
    
    writer.parquet(output_path)

# Partition by date for time-series queries
load_to_data_lake(
    orders_df.withColumn("order_date_partition", F.to_date("order_date")),
    "s3://bucket/processed/orders",
    partition_cols=["order_date_partition"]
)
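Over-partitioning creates small files because, with default settings, each write task writes its own file into every partition directory it holds rows for. The file count can approach tasks × distinct values (for example, 200 tasks × 365 dates = 73,000 files). This plain-Python sketch (hypothetical files_per_partition) counts the files per Hive-style directory that partitionBy would produce.

```python
from collections import defaultdict


def files_per_partition(task_rows, partition_col):
    """Estimate output files for a partitioned write: one file per (task, partition value) pair."""
    writers = defaultdict(set)
    for task_id, rows in enumerate(task_rows):
        for row in rows:
            # partitionBy uses Hive-style directories named <column>=<value>
            writers[f"{partition_col}={row[partition_col]}"].add(task_id)
    return {directory: len(task_ids) for directory, task_ids in writers.items()}


# Three write tasks whose rows are spread across two dates
tasks = [
    [{"order_date_partition": "2024-01-01"}, {"order_date_partition": "2024-01-02"}],
    [{"order_date_partition": "2024-01-01"}],
    [{"order_date_partition": "2024-01-02"}, {"order_date_partition": "2024-01-01"}],
]
layout = files_per_partition(tasks, "order_date_partition")
```

Calling df.repartition("order_date_partition") before the write sends each date to a single task, so each directory gets one file, at the cost of a shuffle and possibly large files for busy dates.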

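For Delta Lake, you get ACID transactions and upserts. This needs the delta-spark package (a release built for your Spark version) and a SparkSession configured with Delta’s SQL extension (spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension) and catalog (spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog):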

from delta.tables import DeltaTable

def upsert_to_delta(spark, df, table_path: str, merge_keys: list):
    """Upsert data to Delta Lake table."""
    
    if DeltaTable.isDeltaTable(spark, table_path):
        delta_table = DeltaTable.forPath(spark, table_path)
        
        merge_condition = " AND ".join([
            f"target.{key} = source.{key}" for key in merge_keys
        ])
        
        (
            delta_table.alias("target")
            .merge(df.alias("source"), merge_condition)
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )
    else:
        df.write.format("delta").save(table_path)
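The merge semantics are easy to reason about with a toy in-memory model. This plain-Python sketch (hypothetical merge_upsert) mimics whenMatchedUpdateAll plus whenNotMatchedInsertAll on lists of dicts, assuming unique keys in the target. It rejects duplicate source keys outright, a simplification of Delta’s behavior of failing when several source rows match the same target row.

```python
def merge_upsert(target, source, merge_keys):
    """Toy model of MERGE ... whenMatchedUpdateAll().whenNotMatchedInsertAll()."""
    def key_of(row):
        return tuple(row[k] for k in merge_keys)

    incoming = {}
    for row in source:
        key = key_of(row)
        # Simplification: reject duplicate source keys up front. Delta errors out when
        # several source rows match the same target row, so dedupe the source first.
        if key in incoming:
            raise ValueError(f"duplicate source rows for key {key}")
        incoming[key] = row

    # Matched target rows are replaced by their source row (UpdateAll); assumes unique target keys
    merged = [incoming.pop(key_of(row), row) for row in target]
    # Source rows with no match are appended (InsertAll)
    merged.extend(incoming.values())
    return merged


target = [{"order_id": "o1", "status": "pending"}, {"order_id": "o2", "status": "shipped"}]
source = [{"order_id": "o1", "status": "delivered"}, {"order_id": "o3", "status": "pending"}]
merged = merge_upsert(target, source, ["order_id"])
```

Rerunning the same merge leaves the table unchanged, which is why upserts make retried pipeline runs safe in a way plain appends are not.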

Building a Production-Ready Pipeline

Wrap everything in a structured class with proper error handling:

import logging
from dataclasses import dataclass
from typing import Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class PipelineConfig:
    source_path: str
    customers_path: str
    output_path: str
    run_date: str

class OrdersETLPipeline:
    def __init__(self, spark: SparkSession, config: PipelineConfig):
        self.spark = spark
        self.config = config
    
    def run(self) -> bool:
        """Execute the complete ETL pipeline."""
        try:
            logger.info(f"Starting pipeline for {self.config.run_date}")
            
            # Extract
            logger.info("Extracting data...")
            orders_df = extract_csv(self.spark, self.config.source_path, orders_schema)
            customers_df = self.spark.read.parquet(self.config.customers_path)
            
            record_count = orders_df.count()
            logger.info(f"Extracted {record_count} orders")
            
            if record_count == 0:
                logger.warning("No records extracted, skipping pipeline")
                return True
            
            # Transform
            logger.info("Transforming data...")
            cleaned_df = clean_orders(orders_df)
            enriched_df = enrich_orders_with_customers(cleaned_df, customers_df)
            final_df = add_customer_metrics(enriched_df)
            
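            # Load: derive the date column the writer partitions by
            logger.info("Loading data...")
            final_df = final_df.withColumn("order_date_partition", F.to_date("order_date"))
            load_to_data_lake(
                final_df,
                f"{self.config.output_path.rstrip('/')}/orders",
                partition_cols=["order_date_partition"]
            )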
            
            logger.info("Pipeline completed successfully")
            return True
            
        except Exception as e:
            logger.error(f"Pipeline failed: {str(e)}", exc_info=True)
            raise

# Usage
if __name__ == "__main__":
    spark = create_spark_session()
    config = PipelineConfig(
        source_path="s3://bucket/raw/orders/",
        customers_path="s3://bucket/dimensions/customers/",
        output_path="s3://bucket/processed/",
        run_date="2024-01-15"
    )
    
    pipeline = OrdersETLPipeline(spark, config)
    try:
        pipeline.run()
    finally:
        spark.stop()  # release cluster resources even if the pipeline fails

Performance Optimization and Best Practices

Most performance issues come from three sources: shuffles, data skew, and small files.

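Reduce shuffles by filtering early and using broadcast joins for small tables. Catalyst usually pushes a simple filter like this below the join on its own, so the explicit version mainly documents intent and protects you when the optimizer can’t push the filter down:

# Risky: filter after join (relies on the optimizer to push it down)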
result = large_df.join(other_df, "key").filter(F.col("status") == "active")

# Good: Filter before join
filtered = large_df.filter(F.col("status") == "active")
result = filtered.join(F.broadcast(other_df), "key")

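Handle skew with AQE or by salting keys. With AQE enabled, skew-join handling (spark.sql.adaptive.skewJoin.enabled, on by default) splits oversized partitions in sort-merge joins automatically; salt keys manually when that isn’t enough: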

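# If one customer_id has millions of records, split it into 10 salted sub-keys;
# the other join side must be replicated once per salt value (0-9) to still match
df_salted = df.withColumn("salted_key", F.concat_ws("_", F.col("customer_id"), (F.rand() * 10).cast("int").cast("string")))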
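Salting only works if both sides of the join agree on the salted keys. This plain-Python sketch (hypothetical salted_join and N_SALTS) gives each large-side row a random salt and replicates each small-side row once per salt value, so every row still finds exactly one match while a single hot key is spread over several join keys. In PySpark, the replication step is typically written with F.explode over an array of literal salt values.

```python
import random
from collections import Counter

N_SALTS = 10  # hypothetical: number of sub-keys to split each hot key into


def salted_join(large, small, key, n_salts=N_SALTS, seed=0):
    """Toy salted inner join: spread a hot key over n_salts sub-keys on both sides."""
    rng = random.Random(seed)
    # Small side: replicate every row once per salt value so each salted key can still match
    lookup = {(row[key], salt): row for row in small for salt in range(n_salts)}
    joined = []
    for row in large:
        # Large side: a random salt scatters rows for the same key across n_salts join keys
        salt = rng.randrange(n_salts)
        match = lookup.get((row[key], salt))
        if match is not None:
            joined.append({**match, **row, "salt": salt})
    return joined


large = [{"customer_id": "hot", "order_id": i} for i in range(1000)]
small = [{"customer_id": "hot", "region": "EU"}]
joined = salted_join(large, small, "customer_id")
rows_per_salt = Counter(row["salt"] for row in joined)
```

The trade-off is that the small side grows n_salts times, so pick the salt count from how skewed the hot key actually is.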

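Avoid small files by coalescing before writes. coalesce merges partitions without a full shuffle, but it also reduces the parallelism of the work that produces the data; if that upstream work is heavy, repartition (which shuffles) can be faster overall: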

# If you have 1000 partitions but only 100MB of data
df.coalesce(10).write.parquet(output_path)
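Rather than hardcoding the partition count, you can derive it from an estimate of the output size. This sketch assumes a target of roughly 128 MB per file, a common rule of thumb rather than a Spark setting, and a size estimate you supply yourself (for example, the on-disk size of a previous run). By this measure, the 10 files in the example above would still be about 10 MB each.

```python
import math

MB = 1024 * 1024
TARGET_FILE_BYTES = 128 * MB  # assumption: a common ~128 MB target; tune for your readers


def output_file_count(total_bytes, target_file_bytes=TARGET_FILE_BYTES):
    """Number of output files (the argument to coalesce/repartition) for a given data size."""
    return max(1, math.ceil(total_bytes / target_file_bytes))


print(output_file_count(100 * MB))         # 100 MB fits in one file
print(output_file_count(10 * 1024 * MB))   # 10 GB -> 80 files of ~128 MB
```

You would then write df.coalesce(output_file_count(estimated_bytes)), keeping in mind that compression makes any size estimate approximate.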

Cache DataFrames you reuse multiple times:

orders_df.cache()
orders_df.count()  # Materialize the cache

# Use it multiple times
summary_df = orders_df.groupBy("status").count()
recent_df = orders_df.filter(F.col("order_date") > "2024-01-01")

orders_df.unpersist()  # Clean up when done

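Monitor your jobs through the Spark UI, served by the driver at http://localhost:4040 when running locally (Spark moves to 4041, 4042, and so on if the port is taken). Look for stages with high shuffle read/write, tasks with uneven durations (a sign of skew), and spill to disk.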

Build your pipelines incrementally. Start with working code, measure actual performance, then optimize the bottlenecks you find. PySpark handles more than you’d expect before optimization becomes necessary.
