ETL Pipeline with PySpark - Complete Tutorial
Key Insights
- PySpark ETL pipelines shine when processing data at scale—start simple with DataFrames, then optimize based on actual bottlenecks rather than premature optimization
- Schema definition upfront prevents silent data corruption and makes debugging infinitely easier than relying on schema inference in production
- Partition your output data by columns you’ll frequently filter on, but avoid over-partitioning which creates thousands of small files that kill read performance
Introduction to ETL and PySpark
ETL—Extract, Transform, Load—forms the backbone of modern data engineering. You pull data from source systems, clean and reshape it, then push it somewhere useful. Simple concept, complex execution.
PySpark makes this manageable at scale. When your CSV files grow from megabytes to terabytes, pandas falls over. PySpark distributes work across a cluster, processing data in parallel while you write familiar Python code. You get the Python ecosystem you know plus Apache Spark’s distributed computing power.
This tutorial assumes you have Python 3.8+ and can install packages. We’ll build a complete ETL pipeline from scratch, covering real-world patterns I’ve used in production systems processing billions of records.
Setting Up Your PySpark Environment
Install PySpark via pip:
```bash
pip install pyspark==3.5.0
```
For local development, that’s it. PySpark bundles a Spark distribution. For production clusters, you’ll coordinate with your infrastructure team on Spark versions and cluster managers.
Every PySpark application starts with a SparkSession:
```python
from pyspark.sql import SparkSession

def create_spark_session(app_name: str = "ETL_Pipeline") -> SparkSession:
    return (
        SparkSession.builder
        .appName(app_name)
        .config("spark.sql.adaptive.enabled", "true")
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .config("spark.sql.parquet.compression.codec", "snappy")
        .getOrCreate()
    )

spark = create_spark_session()
```
Adaptive Query Execution (AQE) dynamically optimizes query plans at runtime—enable it. Kryo serialization outperforms Java serialization for most workloads. These configs give you solid defaults.
Understand the architecture: your driver program orchestrates work, executors do the actual processing, and data lives in partitions distributed across executors. More partitions means more parallelism, but also more overhead.
Extract: Reading Data from Multiple Sources
Reading data seems trivial until you hit malformed records, schema mismatches, or connection timeouts. Define schemas explicitly:
```python
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, TimestampType, DoubleType
)

# Define schema explicitly - don't rely on inference in production
orders_schema = StructType([
    StructField("order_id", StringType(), nullable=False),
    StructField("customer_id", StringType(), nullable=False),
    StructField("product_id", StringType(), nullable=False),
    StructField("quantity", IntegerType(), nullable=False),
    StructField("unit_price", DoubleType(), nullable=False),
    StructField("order_date", TimestampType(), nullable=False),
    StructField("status", StringType(), nullable=True)
])

def extract_csv(spark: SparkSession, path: str, schema: StructType):
    return (
        spark.read
        .option("header", "true")
        .option("mode", "DROPMALFORMED")  # or PERMISSIVE with _corrupt_record
        .option("timestampFormat", "yyyy-MM-dd HH:mm:ss")
        .schema(schema)
        .csv(path)
    )

orders_df = extract_csv(spark, "s3://bucket/raw/orders/*.csv", orders_schema)
```
Schema inference scans your data twice and guesses types. In production, a single malformed file can change your inferred schema and break downstream processes. Define it once, enforce it always.
For database sources, use JDBC:
```python
def extract_from_postgres(spark: SparkSession, table: str, connection_props: dict):
    jdbc_url = f"jdbc:postgresql://{connection_props['host']}:{connection_props['port']}/{connection_props['database']}"
    return (
        spark.read
        .format("jdbc")
        .option("url", jdbc_url)
        .option("dbtable", table)
        .option("user", connection_props["user"])
        .option("password", connection_props["password"])
        .option("driver", "org.postgresql.Driver")
        .option("fetchsize", "10000")  # Tune based on row size
        .option("partitionColumn", "id")  # Enable parallel reads
        .option("lowerBound", "1")
        .option("upperBound", "1000000")
        .option("numPartitions", "10")
        .load()
    )
```
The partition options enable parallel reads from the database. Without them, a single executor fetches everything sequentially.
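For reference, `connection_props` is just a plain dict matching the lookups the function performs — every value below is a placeholder:

```python
# Placeholder connection settings -- substitute your real host and credentials
connection_props = {
    "host": "db.example.com",
    "port": 5432,
    "database": "sales",
    "user": "etl_user",
    "password": "change_me",
}

# The same URL the function builds internally:
jdbc_url = (
    f"jdbc:postgresql://{connection_props['host']}:"
    f"{connection_props['port']}/{connection_props['database']}"
)

# customers_df = extract_from_postgres(spark, "public.customers", connection_props)
```

In practice, pull credentials from a secrets manager or environment variables rather than hardcoding them.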
Transform: Data Processing and Cleaning
Transformations are where your business logic lives. Start with the common operations:
```python
from pyspark.sql import functions as F
from pyspark.sql.window import Window

def clean_orders(df):
    """Apply standard cleaning transformations."""
    return (
        df
        # Remove duplicates
        .dropDuplicates(["order_id"])
        # Handle nulls
        .fillna({"status": "unknown"})
        .filter(F.col("quantity") > 0)
        # Standardize strings
        .withColumn("status", F.lower(F.trim(F.col("status"))))
        # Add derived columns
        .withColumn("total_amount", F.col("quantity") * F.col("unit_price"))
        .withColumn("processed_at", F.current_timestamp())
    )
```
Joins are common but expensive. Understand your data sizes:
```python
def enrich_orders_with_customers(orders_df, customers_df):
    """Join orders with customer data."""
    # If customers_df is small (< 10MB), broadcast it
    customers_small = F.broadcast(customers_df.select(
        "customer_id", "customer_name", "region", "segment"
    ))
    return orders_df.join(
        customers_small,
        on="customer_id",
        how="left"
    )
```
For custom logic that doesn’t fit built-in functions, use UDFs sparingly:
```python
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import re

@udf(returnType=StringType())
def validate_email(email):
    """Validate and normalize email addresses."""
    if email is None:
        return None
    email = email.lower().strip()
    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    return email if re.match(pattern, email) else None

# Apply the UDF
customers_df = customers_df.withColumn(
    "validated_email",
    validate_email(F.col("email"))
)
```
UDFs serialize data between the JVM and Python row by row, killing performance. Prefer built-in functions when possible. For heavy UDF usage, consider pandas UDFs, which process data in batches via Apache Arrow.
Window functions handle running totals, rankings, and time-based calculations:
```python
def add_customer_metrics(orders_df):
    """Add customer-level aggregations using window functions."""
    customer_window = Window.partitionBy("customer_id")
    order_window = (
        Window
        .partitionBy("customer_id")
        .orderBy("order_date")
    )
    return (
        orders_df
        .withColumn("customer_total_orders", F.count("order_id").over(customer_window))
        .withColumn("customer_lifetime_value", F.sum("total_amount").over(customer_window))
        .withColumn("order_sequence", F.row_number().over(order_window))
        .withColumn(
            "days_since_last_order",
            F.datediff(
                F.col("order_date"),
                F.lag("order_date").over(order_window)
            )
        )
    )
```
Load: Writing Data to Target Systems
Writing data efficiently requires understanding partitioning and file formats:
```python
from typing import Optional

def load_to_data_lake(df, output_path: str, partition_cols: Optional[list] = None):
    """Write DataFrame to Parquet with optional partitioning."""
    writer = (
        df.write
        .mode("overwrite")
        .option("compression", "snappy")
    )
    if partition_cols:
        writer = writer.partitionBy(*partition_cols)
    writer.parquet(output_path)

# Partition by date for time-series queries
load_to_data_lake(
    orders_df.withColumn("order_date_partition", F.to_date("order_date")),
    "s3://bucket/processed/orders",
    partition_cols=["order_date_partition"]
)
```
For Delta Lake, you get ACID transactions and upserts:
```python
from delta.tables import DeltaTable

def upsert_to_delta(spark, df, table_path: str, merge_keys: list):
    """Upsert data to Delta Lake table."""
    if DeltaTable.isDeltaTable(spark, table_path):
        delta_table = DeltaTable.forPath(spark, table_path)
        merge_condition = " AND ".join([
            f"target.{key} = source.{key}" for key in merge_keys
        ])
        (
            delta_table.alias("target")
            .merge(df.alias("source"), merge_condition)
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )
    else:
        df.write.format("delta").save(table_path)
```
Building a Production-Ready Pipeline
Wrap everything in a structured class with proper error handling:
```python
import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class PipelineConfig:
    source_path: str
    customers_path: str
    output_path: str
    run_date: str

class OrdersETLPipeline:
    def __init__(self, spark: SparkSession, config: PipelineConfig):
        self.spark = spark
        self.config = config

    def run(self) -> bool:
        """Execute the complete ETL pipeline."""
        try:
            logger.info(f"Starting pipeline for {self.config.run_date}")

            # Extract
            logger.info("Extracting data...")
            orders_df = extract_csv(self.spark, self.config.source_path, orders_schema)
            customers_df = self.spark.read.parquet(self.config.customers_path)

            record_count = orders_df.count()
            logger.info(f"Extracted {record_count} orders")
            if record_count == 0:
                logger.warning("No records extracted, skipping pipeline")
                return True

            # Transform
            logger.info("Transforming data...")
            cleaned_df = clean_orders(orders_df)
            enriched_df = enrich_orders_with_customers(cleaned_df, customers_df)
            final_df = add_customer_metrics(enriched_df)

            # Load (derive the partition column before writing)
            logger.info("Loading data...")
            load_to_data_lake(
                final_df.withColumn("order_date_partition", F.to_date("order_date")),
                f"{self.config.output_path}/orders",
                partition_cols=["order_date_partition"]
            )

            logger.info("Pipeline completed successfully")
            return True
        except Exception as e:
            logger.error(f"Pipeline failed: {str(e)}", exc_info=True)
            raise

# Usage
if __name__ == "__main__":
    spark = create_spark_session()
    config = PipelineConfig(
        source_path="s3://bucket/raw/orders/",
        customers_path="s3://bucket/dimensions/customers/",
        output_path="s3://bucket/processed/",
        run_date="2024-01-15"
    )
    pipeline = OrdersETLPipeline(spark, config)
    pipeline.run()
```
Performance Optimization and Best Practices
Most performance issues come from three sources: shuffles, data skew, and small files.
Reduce shuffles by filtering early and using broadcast joins for small tables:
```python
# Bad: filter after the join
result = large_df.join(other_df, "key").filter(F.col("status") == "active")

# Good: filter before the join
filtered = large_df.filter(F.col("status") == "active")
result = filtered.join(F.broadcast(other_df), "key")
```
Handle skew by salting keys or using AQE:
```python
# If one customer_id has millions of records, salt it
df_salted = df.withColumn(
    "salt",
    F.concat(F.col("customer_id"), F.lit("_"), (F.rand() * 10).cast("int"))
)
```
Avoid small files by coalescing before writes:
```python
# If you have 1000 partitions but only 100MB of data
df.coalesce(10).write.parquet(output_path)
```
Cache DataFrames you reuse multiple times:
```python
orders_df.cache()
orders_df.count()  # Materialize the cache

# Use it multiple times
summary_df = orders_df.groupBy("status").count()
recent_df = orders_df.filter(F.col("order_date") > "2024-01-01")

orders_df.unpersist()  # Clean up when done
```
Monitor your jobs through Spark UI at http://localhost:4040. Look for stages with high shuffle read/write, tasks with uneven durations (skew), and spill to disk.
Build your pipelines incrementally. Start with working code, measure actual performance, then optimize the bottlenecks you find. PySpark handles more than you’d expect before optimization becomes necessary.