PySpark - Common Mistakes and How to Avoid Them

Key Insights

  • Collecting large datasets to the driver node is the fastest way to crash your Spark job—use .take(), .show(), or .toLocalIterator() instead of .collect() for debugging and sampling.
  • Python UDFs serialize data between the JVM and Python worker processes, which can cause 10-100x slowdowns compared to native Spark SQL functions. Always check whether a built-in function exists first.
  • Data skew silently kills performance by overloading single executors while others sit idle—monitor partition sizes and apply salting techniques for skewed joins.

Introduction

PySpark promises distributed computing at scale, but developers transitioning from pandas or traditional Python consistently fall into the same traps. The mental model shift is significant: you’re no longer working with data that fits in memory on a single machine. You’re writing transformation plans that Spark’s optimizer will execute across a cluster.

The mistakes covered here aren’t edge cases. They’re patterns I’ve seen crash production jobs, balloon cloud bills, and turn 10-minute jobs into 10-hour nightmares. Understanding why these patterns fail—not just that they fail—will make you a more effective Spark developer.

Collecting Large Datasets to the Driver

The .collect() method brings every row from your distributed DataFrame to the driver node’s memory. For a 100-row test DataFrame, this works fine. For a 100-million-row production dataset, you’ll get an OutOfMemoryError and a crashed job.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# DON'T DO THIS - will crash on large datasets
df = spark.read.parquet("s3://data-lake/events/")
all_rows = df.collect()  # Pulls entire dataset to driver memory

# BETTER - take a sample for debugging
sample_rows = df.take(10)  # Only pulls 10 rows

# BETTER - display rows without collecting the whole DataFrame
df.show(20, truncate=False)  # Fetches only 20 rows and prints them to stdout

# BETTER - limit then collect if you need a list
small_sample = df.limit(100).collect()  # Caps at 100 rows first

# BEST for iteration - process in chunks
for row in df.toLocalIterator():
    # Fetches one partition at a time, so the driver only needs
    # enough memory for the largest partition
    process_row(row)  # process_row stands in for your own per-row handler

The key insight is that .collect() has no row limit of its own. It will attempt to pull terabytes of data into gigabytes of driver memory; at best, the job is aborted once the serialized results exceed spark.driver.maxResultSize (1 GB by default). Use .limit() before .collect() as a defensive pattern, or use .toLocalIterator() when you genuinely need to process all rows locally but can do so incrementally.

When debugging, .show() is almost always what you want. It’s designed for human inspection and handles truncation gracefully.

Using Python UDFs When Native Functions Exist

Python UDFs are convenient but expensive. When you call a Python UDF, Spark must serialize each row from the JVM, send it to a Python worker process, execute your function, and serialize the result back. This serialization overhead dominates execution time.

from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# SLOW - Python UDF for simple string operation
@udf(returnType=StringType())
def clean_email_udf(email):
    if email is None:
        return None
    return email.lower().strip()

df_slow = df.withColumn("clean_email", clean_email_udf(F.col("email")))

# FAST - Native Spark functions (10-100x faster)
# Note: F.trim strips spaces only; Python's str.strip() also removes tabs/newlines
df_fast = df.withColumn("clean_email", F.lower(F.trim(F.col("email"))))

The performance difference is dramatic. I’ve seen jobs drop from 45 minutes to 3 minutes just by replacing UDFs with native functions. The pyspark.sql.functions module is extensive—check it before writing a UDF.

# More UDF replacements

# SLOW UDF for date extraction
@udf(returnType=StringType())
def extract_year_udf(date_str):
    return date_str[:4] if date_str else None

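# FAST native equivalent (F.year returns an integer and needs a date/timestamp
# column or a date string Spark can cast, such as "2024-03-15")
df.withColumn("year", F.year(F.col("date_column")))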

# SLOW UDF for conditional logic
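@udf(returnType=StringType())
def categorize_udf(amount):
    if amount is None:
        return "low"  # Match the native version, where null falls through to otherwise()
    if amount > 1000:
        return "high"
    elif amount > 100:
        return "medium"
    return "low"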

# FAST native equivalent
df.withColumn("category", 
    F.when(F.col("amount") > 1000, "high")
     .when(F.col("amount") > 100, "medium")
     .otherwise("low")
)

If you must use a UDF, consider Pandas UDFs (vectorized UDFs), which operate on batches and use Apache Arrow for faster serialization. But native functions should always be your first choice.

Ignoring Data Skew and Partition Imbalance

Data skew occurs when some partitions contain vastly more data than others. Spark processes partitions in parallel, so your job completes only when the slowest partition finishes. One partition with 10 million rows while others have 10 thousand means 99% of your cluster sits idle waiting for that single executor.
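The arithmetic behind that figure is easy to reproduce in plain Python. This sketch rests on simplifying assumptions: one core per partition, the same per-row cost everywhere, and 100 partitions in total. The paragraph above doesn't give a partition count, so 100 is an assumption. Under these assumptions a stage lasts as long as its largest partition takes. Utilization is therefore total rows divided by (partitions x largest partition).

```python
from typing import Sequence


def stage_utilization(partition_rows: Sequence[int]) -> float:
    """Fraction of core-time spent doing useful work in one stage.

    Assumes one core per partition and identical per-row cost, so the
    stage lasts as long as its largest partition.
    """
    longest = max(partition_rows)
    if longest == 0:
        return 1.0
    busy = sum(partition_rows)                  # useful row-work done
    available = len(partition_rows) * longest   # cores x stage duration
    return busy / available


skewed = [10_000_000] + [10_000] * 99    # one hot partition, 99 small ones
balanced = [sum(skewed) // 100] * 100    # same total rows, spread evenly

print(f"skewed:   {stage_utilization(skewed):.1%} busy")    # skewed:   1.1% busy
print(f"balanced: {stage_utilization(balanced):.1%} busy")  # balanced: 100.0% busy
```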

# Identify skew by checking partition sizes
from pyspark.sql import functions as F

# Check distribution of a join key
df.groupBy("customer_id").count().orderBy(F.desc("count")).show(20)

# Check actual partition sizes
def get_partition_sizes(df):
    return df.rdd.mapPartitions(
        lambda partition: [sum(1 for _ in partition)]
    ).collect()

partition_sizes = get_partition_sizes(df)
print(f"Min: {min(partition_sizes)}, Max: {max(partition_sizes)}, "
      f"Ratio: {max(partition_sizes)/max(min(partition_sizes), 1):.1f}x")

When you find skew, salting is the standard fix for joins. You add random noise to the skewed key, spreading hot keys across multiple partitions:

from pyspark.sql import functions as F

# Salting technique for skewed joins
SALT_BUCKETS = 10

# Salt the large (skewed) table: each row gets a random salt in [0, SALT_BUCKETS)
large_df_salted = large_df.withColumn(
    "salt", (F.rand() * SALT_BUCKETS).cast("int")
).withColumn(
    "salted_key",
    F.concat(F.col("join_key").cast("string"), F.lit("_"), F.col("salt").cast("string"))
)

# Explode the small table so each key appears once per salt value
small_df_exploded = small_df.crossJoin(
    spark.range(SALT_BUCKETS).withColumnRenamed("id", "salt")
).withColumn(
    "salted_key",
    F.concat(F.col("join_key").cast("string"), F.lit("_"), F.col("salt").cast("string"))
).drop("join_key", "salt")  # Avoid duplicate join_key/salt columns after the join

# Join on salted key - each hot key is now spread across up to SALT_BUCKETS partitions
result = large_df_salted.join(
    small_df_exploded,
    on="salted_key",
    how="inner"
).drop("salt", "salted_key")

Inefficient Join Strategies

Spark supports multiple join strategies, and choosing the wrong one can tank performance. The two main strategies are shuffle joins (both tables redistributed by join key) and broadcast joins (small table sent to all executors).

# Check your join strategy with explain()
df1.join(df2, on="key").explain(mode="formatted")

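# Spark auto-broadcasts tables below spark.sql.autoBroadcastJoinThreshold
# (10MB by default); the broadcast() hint forces it regardless of size estimates
from pyspark.sql.functions import broadcast

# SLOW - shuffle join (both tables redistributed) when Spark can't tell
# that small_lookup_df is small
result_slow = large_df.join(small_lookup_df, on="category_id")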

# FAST - broadcast join (small table sent to all executors)
result_fast = large_df.join(broadcast(small_lookup_df), on="category_id")

The query plan tells you what Spark chose. Look for BroadcastHashJoin (good for small tables) versus SortMergeJoin (necessary when joining two large tables, but expensive because both sides are shuffled and sorted).

# Reading the explain output
result.explain()
# Look for:
# - BroadcastHashJoin: Small table broadcast, fast
# - SortMergeJoin: Both tables shuffled and sorted, slower
# - BroadcastNestedLoopJoin: Usually bad; appears when there is no equi-join key
#   (cross or non-equi joins), so rows are compared pairwise

# Filter BEFORE joining to reduce shuffle size
# BAD - filter after join
result_bad = orders.join(customers, on="customer_id").filter(
    F.col("order_date") > "2024-01-01"
)

# GOOD - filter before join
result_good = orders.filter(
    F.col("order_date") > "2024-01-01"
).join(customers, on="customer_id")

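For a simple, deterministic predicate like this one, Catalyst often pushes the filter below an inner join on its own, so both versions may compile to the same plan. Compare their .explain() output before assuming the rewrite helped. Filtering early explicitly still matters when the optimizer can't push the predicate down. Join order matters too: when joining multiple tables, start with the most selective filters and smallest intermediate results.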

Repeated Computations Without Caching

Spark uses lazy evaluation—transformations don’t execute until an action triggers them. Each action recomputes the entire lineage from scratch. If you use a DataFrame multiple times, you’re recomputing it each time.

# BAD - df_transformed computed twice
# (complex_calculation() is a placeholder for your own expensive column expression)
df_transformed = (
    spark.read.parquet("s3://bucket/raw/")
    .filter(F.col("status") == "active")
    .withColumn("score", complex_calculation())
)

count = df_transformed.count()  # Computes df_transformed
df_transformed.write.parquet("s3://bucket/output/")  # Computes again!

# GOOD - cache intermediate result
df_transformed = (
    spark.read.parquet("s3://bucket/raw/")
    .filter(F.col("status") == "active")
    .withColumn("score", complex_calculation())
).cache()  # Or .persist(StorageLevel.MEMORY_AND_DISK)

count = df_transformed.count()  # Computes and caches
df_transformed.write.parquet("s3://bucket/output/")  # Uses cache

# Clean up when done
df_transformed.unpersist()

Choose your persistence level based on data size and reuse patterns:

from pyspark import StorageLevel

# Alternatives - apply ONE level per DataFrame

# Memory only - fastest, but partitions that don't fit aren't cached
# and are recomputed when needed (nothing spills to disk)
df.persist(StorageLevel.MEMORY_ONLY)

# Memory and disk - safer, spills to disk if needed
df.persist(StorageLevel.MEMORY_AND_DISK)

# Disk only - for very large datasets you'll reuse
df.persist(StorageLevel.DISK_ONLY)

# Serialized - PySpark has no MEMORY_ONLY_SER; its StorageLevel constants
# (except MEMORY_AND_DISK_DESER) already store data in serialized form

Always unpersist when you’re done. Cached DataFrames consume cluster memory that could be used for computation.

Conclusion

These mistakes share a common thread: they ignore the distributed nature of Spark. Every operation has a cost measured in network transfer, serialization, and coordination overhead.

Your optimization checklist:

  • Never .collect() without a .limit() first
  • Replace Python UDFs with native pyspark.sql.functions
  • Monitor partition sizes and apply salting for skewed joins
  • Use broadcast() hints for small lookup tables
  • Cache DataFrames used multiple times, unpersist when done
  • Read .explain() output to understand your query plans

The Spark UI is your best friend for ongoing optimization. Check the Stages tab for skewed tasks, the Storage tab for cached DataFrames, and the SQL tab for query plan visualization. These patterns will get you 80% of the way to performant Spark jobs—the Spark UI will help you find the remaining 20%.
