Apache Spark - Column Pruning

Key Insights

  • Column pruning is Spark’s automatic optimization that reads only the columns your query actually needs, dramatically reducing I/O and memory usage—especially with columnar formats like Parquet and ORC.
  • Certain coding patterns silently break column pruning, including SELECT *, row-based UDFs, and premature DataFrame caching, causing Spark to read far more data than necessary.
  • Always verify pruning is working by checking the ReadSchema in physical plans with explain("formatted") and comparing scan input sizes in the Spark UI.

Introduction to Column Pruning

Column pruning is one of Spark’s most impactful automatic optimizations, yet many developers never think about it—until their jobs run ten times slower than expected. The concept is straightforward: when you query a table with 100 columns but only need 3, Spark should read just those 3 columns from storage, not the entire dataset.

This matters enormously at scale. Consider a production table with 200 columns and 10 billion rows stored in Parquet. If your aggregation only needs user_id and purchase_amount, column pruning can reduce data read from terabytes to gigabytes. That translates directly to faster job completion, lower cloud costs, and reduced cluster memory pressure.

The optimization happens automatically through Spark’s Catalyst optimizer—but only when your code cooperates. Understanding how pruning works, what breaks it, and how to verify it’s happening will make you a more effective Spark developer.

How Spark’s Catalyst Optimizer Handles Column Pruning

Spark’s Catalyst optimizer analyzes your query’s logical plan to determine which columns are actually required for the final result. It then pushes column selections down to the scan operators, ensuring the data source reads only what’s needed.

The process works through several optimization rules, with ColumnPruning being the primary one. Catalyst traces column references through your entire query—filters, joins, aggregations, projections—building a set of required columns. It then rewrites the plan to project only those columns as early as possible.
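To make the tracing idea concrete, here is a toy model in plain Python. The node classes (Scan, Filter, Project, Aggregate) and the function required_scan_columns are hypothetical and far simpler than Catalyst's real ColumnPruning rule. The sketch only shows the core idea: walk the plan from the top, and at each operator work out which columns it needs from its child.

```python
from dataclasses import dataclass
from typing import FrozenSet, Optional, Set, Tuple, Union

# Toy plan nodes for illustration only; these are NOT Spark's Catalyst classes.
@dataclass(frozen=True)
class Scan:
    table_columns: Tuple[str, ...]

@dataclass(frozen=True)
class Filter:
    predicate_columns: FrozenSet[str]  # columns the filter condition references
    child: "Plan"

@dataclass(frozen=True)
class Project:
    columns: Tuple[str, ...]
    child: "Plan"

@dataclass(frozen=True)
class Aggregate:
    group_by: Tuple[str, ...]
    agg_inputs: Tuple[str, ...]  # columns fed into aggregate functions
    child: "Plan"

Plan = Union[Scan, Filter, Project, Aggregate]

def required_scan_columns(plan: Plan, needed: Optional[Set[str]] = None) -> Set[str]:
    """Return the columns the scan must read.

    `needed` is the set of columns the parent requires from `plan`;
    None means the parent wants every column (e.g. the query's final output).
    """
    if isinstance(plan, Scan):
        if needed is None:
            return set(plan.table_columns)
        return {c for c in plan.table_columns if c in needed}
    if isinstance(plan, Project):
        # A projection only needs the columns it lists.
        return required_scan_columns(plan.child, set(plan.columns))
    if isinstance(plan, Filter):
        # A filter passes rows through unchanged, so it needs whatever its
        # parent needs plus the columns referenced by its predicate.
        child_needed = None if needed is None else needed | plan.predicate_columns
        return required_scan_columns(plan.child, child_needed)
    if isinstance(plan, Aggregate):
        # An aggregate needs only its grouping keys and aggregate inputs.
        return required_scan_columns(plan.child, set(plan.group_by) | set(plan.agg_inputs))
    raise TypeError(f"unknown plan node: {plan!r}")

sales = Scan(("customer_id", "revenue", "region", "product_id", "timestamp"))
query = Project(("customer_id", "revenue"), Filter(frozenset({"region"}), sales))
print(sorted(required_scan_columns(query)))  # ['customer_id', 'region', 'revenue']
```

The same walk explains the aggregation example later in this article: grouping by user_id and summing revenue after a filter on event_date leaves exactly those three columns for the scan.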

Here’s how to observe this in action:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("ColumnPruning").getOrCreate()

# Read a wide Parquet table (the path is illustrative)
df = spark.read.parquet("/data/sales_transactions")

# Query using only two columns
result = df.filter(df.region == "EMEA").select("customer_id", "revenue")

# Examine the full optimization process
result.explain(True)

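The output shows multiple plan stages (abbreviated here; real output also includes expression IDs such as customer_id#0 and extra scan details):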

== Parsed Logical Plan ==
Project [customer_id, revenue]
+- Filter (region = EMEA)
   +- Relation[customer_id, revenue, region, product_id, timestamp, ...] parquet

== Analyzed Logical Plan ==
...

== Optimized Logical Plan ==
Project [customer_id, revenue]
+- Filter (isnotnull(region) AND (region = EMEA))
   +- Relation[customer_id, revenue, region, product_id, timestamp, ...] parquet

== Physical Plan ==
*(1) Project [customer_id, revenue]
+- *(1) Filter (isnotnull(region) AND (region = EMEA))
   +- *(1) ColumnarToRow
      +- FileScan parquet [customer_id,revenue,region] ReadSchema: struct<customer_id:string,revenue:double,region:string>

Notice how the ReadSchema in the physical plan contains only three columns: the two we selected plus region for the filter. For file sources such as Parquet, the Relation in the logical plans still lists the full table schema; Catalyst works out the required columns from the projections and filters above it, and the physical planner pushes exactly that set down into the file scan.

Column Pruning with Different Data Sources

Column pruning’s effectiveness varies dramatically based on your data format. Columnar formats like Parquet and ORC store data by column, making it physically possible to read specific columns without touching others. Row-based formats like CSV and JSON must read entire rows regardless of which columns you need.

import time

def timed_scan(df):
    # Write to Spark's built-in "noop" sink (Spark 3.0+) so the selected columns
    # are actually read. A bare .count() needs no columns, so Catalyst would prune
    # the select away and the comparison would not measure reading these columns.
    start = time.time()
    df.write.format("noop").mode("overwrite").save()
    return time.time() - start

# Parquet: columnar format - only the selected column chunks are read
parquet_time = timed_scan(
    spark.read.parquet("/data/large_table.parquet").select("user_id", "event_type"))

# CSV: row-based format - every line must still be read in full
csv_time = timed_scan(
    spark.read.csv("/data/large_table.csv", header=True).select("user_id", "event_type"))

print(f"Parquet read time: {parquet_time:.2f}s")
print(f"CSV read time: {csv_time:.2f}s")

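On a table with 50 columns, selecting only 2 columns will often let Parquet outperform CSV by 5-10x, although the exact ratio depends on compression, data types, and cluster setup. The gap widens as column count increases, because the CSV reader's cost grows with the full row width while Parquet's grows only with the columns you select.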
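The reason the layout matters can be sketched with a toy model in plain Python. It is a deliberate simplification, not Parquet's or CSV's real on-disk format (no compression, encodings, row groups, or page headers): the row layout stores one comma-separated line per row, the column layout stores one contiguous chunk per column, and both hypothetical readers report how many characters they had to scan to return the same two columns.

```python
from typing import Dict, List, Tuple

WIDTH = 8  # every synthetic value is exactly 8 characters

def make_rows(n_rows: int, n_cols: int) -> List[Dict[str, str]]:
    # Synthetic table: the value for row 42, column 7 is "0042v007"
    return [{f"c{j}": f"{i:04d}v{j:03d}" for j in range(n_cols)} for i in range(n_rows)]

def to_row_layout(rows):
    # Row-oriented: one comma-separated line per row, like a CSV file
    return [",".join(r.values()) + "\n" for r in rows]

def to_column_layout(rows):
    # Column-oriented: one contiguous chunk of values per column
    return {name: "".join(r[name] for r in rows) for name in rows[0]}

def read_rows(lines, header, wanted) -> Tuple[List[tuple], int]:
    # A row reader must scan every full line to pull out the wanted fields.
    idx = [header.index(c) for c in wanted]
    scanned, out = 0, []
    for line in lines:
        scanned += len(line)
        fields = line.rstrip("\n").split(",")
        out.append(tuple(fields[i] for i in idx))
    return out, scanned

def read_columns(chunks, wanted) -> Tuple[List[tuple], int]:
    # A columnar reader touches only the chunks for the wanted columns.
    scanned = sum(len(chunks[c]) for c in wanted)
    cols = [[chunks[c][k:k + WIDTH] for k in range(0, len(chunks[c]), WIDTH)] for c in wanted]
    return list(zip(*cols)), scanned

rows = make_rows(n_rows=1000, n_cols=50)
header = list(rows[0])
wanted = ["c0", "c1"]
row_result, row_scanned = read_rows(to_row_layout(rows), header, wanted)
col_result, col_scanned = read_columns(to_column_layout(rows), wanted)
print(row_scanned, col_scanned)  # 450000 16000
```

Both readers return identical values, but the row reader scans 450,000 characters against the columnar reader's 16,000, roughly 28x. Real files shift the exact numbers; the point is that columnar cost scales with the columns selected, row cost with the whole row.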

ORC provides similar benefits to Parquet, with some workloads favoring one over the other. Both support predicate pushdown alongside column pruning, compounding the optimization benefits.

For JSON data, Spark must parse entire JSON objects even when you only need specific fields. If you’re stuck with JSON sources, consider a preprocessing step that converts to Parquet for repeated analytical queries.

Common Scenarios That Break Column Pruning

Several coding patterns silently defeat column pruning, causing Spark to read far more data than your query logically requires.

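The SELECT * trap: Spark reads every column whenever the query's final output includes every column. An intermediate df.select("*") is harmless if a later select narrows the result, because Catalyst collapses adjacent projections, but a wide DataFrame that reaches a write, collect(), toPandas(), or cache() forces a full read. This seems obvious, but it often hides in helper functions and intermediate transformations that pass the whole DataFrame along.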

Row-based UDFs: When you pass entire rows to UDFs, Spark cannot determine which columns the function actually uses, so it reads all of them:

from pyspark.sql.functions import udf, struct
from pyspark.sql.types import StringType

# BAD: This breaks column pruning
@udf(StringType())
def bad_categorize(row):
    # Function receives entire row, so Spark must read all columns
    if row.revenue is not None and row.revenue > 1000:
        return "high"
    return "low"

# struct("*") packs every column into the UDF argument, forcing a full read
result = df.select(bad_categorize(struct("*")).alias("category"))

# GOOD: Reference only needed columns
@udf(StringType())
def good_categorize(revenue):
    # Guard against nulls: None > 1000 raises TypeError in Python 3
    if revenue is not None and revenue > 1000:
        return "high"
    return "low"

# Spark reads only the revenue column
result = df.select(good_categorize(df.revenue).alias("category"))
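Why the struct("*") version defeats pruning can be shown with another toy model. The classes below (Col, StructAll, OpaqueUDF) are hypothetical stand-ins, not PySpark types. The key assumption, which matches how Python UDFs are treated, is that the optimizer sees only a UDF's argument expressions and never its body, so the columns it must read are exactly the columns referenced by the arguments.

```python
from dataclasses import dataclass
from typing import Sequence, Set, Tuple, Union

@dataclass(frozen=True)
class Col:
    name: str

@dataclass(frozen=True)
class StructAll:
    """Stands in for struct("*"): every input column packed into one value."""

@dataclass(frozen=True)
class OpaqueUDF:
    # Only the argument expressions are visible; the Python body is a black box.
    args: Tuple["Expr", ...]

Expr = Union[Col, StructAll, OpaqueUDF]

def referenced_columns(expr: Expr, input_columns: Sequence[str]) -> Set[str]:
    """Columns an expression forces the scan to read."""
    if isinstance(expr, Col):
        return {expr.name}
    if isinstance(expr, StructAll):
        return set(input_columns)
    if isinstance(expr, OpaqueUDF):
        refs: Set[str] = set()
        for arg in expr.args:
            refs |= referenced_columns(arg, input_columns)
        return refs
    raise TypeError(f"unknown expression: {expr!r}")

table = ("customer_id", "revenue", "region", "product_id", "timestamp")
# Even if the body only uses row.revenue, the optimizer sees struct("*")
print(sorted(referenced_columns(OpaqueUDF((StructAll(),)), table)))
# ['customer_id', 'product_id', 'region', 'revenue', 'timestamp']
# Passing the single column exposes exactly what is needed
print(sorted(referenced_columns(OpaqueUDF((Col("revenue"),)), table)))  # ['revenue']
```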

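Premature caching: .cache() and .persist() are lazy, but the cached plan is fixed at the point you call them. When the first action runs, Spark reads and stores every column of that DataFrame, even if every later query selects only a few: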

# BAD: Caches all columns, then selects
df.cache()
result = df.select("col1", "col2").filter(df.col1 > 100)

# GOOD: Select first, then cache
result = df.select("col1", "col2").filter(df.col1 > 100).cache()

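Certain join patterns: Catalyst usually prunes columns through joins, but when the joined result is cached, written out, or passed to a row-level UDF, every column from both inputs is carried along, and for a broadcast join those extra columns are also shipped to every executor. Select only the needed columns on each side before joining.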

Verifying Column Pruning in Your Queries

Don’t assume pruning is working—verify it. Spark provides several tools for inspection.

The explain("formatted") mode, available since Spark 3.0, gives the clearest view of which columns will actually be read:

df = spark.read.parquet("/data/events")

query = df.filter(df.event_date >= "2024-01-01") \
    .groupBy("user_id") \
    .agg({"revenue": "sum"})

query.explain("formatted")

Look for the ReadSchema field in the output:

(1) Scan parquet 
Output [3]: [user_id, revenue, event_date]
Batched: true
Location: InMemoryFileIndex [/data/events]
PushedFilters: [GreaterThanOrEqual(event_date,2024-01-01)]
ReadSchema: struct<user_id:string,revenue:double,event_date:date>

The ReadSchema shows exactly which columns Spark will read from storage. If you see columns here that your query doesn’t need, something is preventing pruning.
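To check this automatically, for example in a regression test, you can parse the ReadSchema lines out of the plan text. The helper below is a plain-Python sketch, not a Spark API: read_schema_columns is a hypothetical name, and it assumes the ReadSchema value ends its line, as in the output above. Commas inside nested types such as decimal(10,2) or struct<...> are skipped by tracking bracket depth. Spark may abbreviate very wide schemas in plan text, in which case the parsed list will be incomplete.

```python
import re
from typing import List

def _top_level_names(fields: str) -> List[str]:
    """Split 'a:int,b:decimal(10,2),c:struct<x:int>' into ['a', 'b', 'c']."""
    names, depth, start = [], 0, 0
    for i, ch in enumerate(fields):
        if ch in "<(":
            depth += 1
        elif ch in ">)":
            depth -= 1
        elif ch == "," and depth == 0:
            names.append(fields[start:i].split(":", 1)[0])
            start = i + 1
    if fields:
        names.append(fields[start:].split(":", 1)[0])
    return names

def read_schema_columns(plan_text: str) -> List[List[str]]:
    """Top-level column names from every 'ReadSchema: struct<...>' line in a plan."""
    pattern = re.compile(r"ReadSchema: struct<(.*)>[ \t]*$", re.MULTILINE)
    return [_top_level_names(m.group(1)) for m in pattern.finditer(plan_text)]

sample_plan = """(1) Scan parquet
Output [3]: [user_id, revenue, event_date]
ReadSchema: struct<user_id:string,revenue:double,event_date:date>
"""
print(read_schema_columns(sample_plan))  # [['user_id', 'revenue', 'event_date']]

# Assumed PySpark usage (PySpark's explain() prints through Python's stdout):
# import contextlib, io
# buf = io.StringIO()
# with contextlib.redirect_stdout(buf):
#     query.explain("formatted")
# assert read_schema_columns(buf.getvalue()) == [["user_id", "revenue", "event_date"]]
```

A query that needs no columns at all, such as a plain count, shows up as an empty struct<> and parses to an empty list; a join produces one list per scan.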

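In the Spark UI, open the SQL tab and examine the scan node details. Keep in mind that "number of output rows" is unaffected by column pruning, and "size of files read" reflects the total size of the files the scan selected rather than the column chunks actually fetched. To gauge how much data pruning saves, compare the Input size reported for the scan's stage in the Stages tab before and after narrowing the query.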

Best Practices and Performance Tips

Apply these patterns consistently to ensure column pruning works effectively across your Spark applications.

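Select columns as early as possible. Catalyst can often prune through plain filters and joins on its own, but an explicit select right after the read keeps the DataFrame narrow even when it later passes through a cache, a row-level UDF, or a write, and it documents which columns the job depends on: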

# Before: Wide DataFrame flows through multiple transformations
df = spark.read.parquet("/data/transactions")
filtered = df.filter(df.status == "completed")
joined = filtered.join(other_df, "customer_id")
result = joined.select("customer_id", "amount", "category")

# After: Narrow the DataFrame immediately
df = spark.read.parquet("/data/transactions") \
    .select("customer_id", "amount", "category", "status")
filtered = df.filter(df.status == "completed")
joined = filtered.join(other_df, "customer_id")
result = joined.select("customer_id", "amount", "category")

Use schema projection on read. For maximum control, specify the schema when reading:

from pyspark.sql.types import StructType, StructField, StringType, DoubleType

# Define only the columns you need
schema = StructType([
    StructField("customer_id", StringType(), True),
    StructField("revenue", DoubleType(), True),
    StructField("region", StringType(), True)
])

df = spark.read.schema(schema).parquet("/data/sales")

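Because columns missing from the supplied schema never become part of the DataFrame, no later transformation can accidentally pull them back in. The field names must match the file's column names: a field that doesn't match comes back as all nulls rather than raising an error.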

Benchmark your changes. Measure the impact of pruning optimizations:

import time
from pyspark.sql.functions import col

def benchmark_query(description, query_func, iterations=3):
    times = []
    for _ in range(iterations):
        start = time.time()
        # Force every output column to be computed via the "noop" sink (Spark 3.0+);
        # .count() would let Catalyst prune away the columns being compared
        query_func().write.format("noop").mode("overwrite").save()
        times.append(time.time() - start)
    avg_time = sum(times) / len(times)
    print(f"{description}: {avg_time:.2f}s average")

# Compare approaches
benchmark_query("All columns",
    lambda: spark.read.parquet("/data/large").filter(col("status") == "active"))

benchmark_query("Pruned columns",
    lambda: spark.read.parquet("/data/large").select("id", "status").filter(col("status") == "active"))
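An average of three runs can be skewed by the first run, which also pays for file listing and cold caches. A slightly more robust variant, sketched below with only the standard library, discards warm-up runs and reports the median. The function name benchmark and its parameters are illustrative, and the callable you pass in is responsible for forcing full execution of the query.

```python
import statistics
import time
from typing import Callable, Dict

def benchmark(run: Callable[[], object], iterations: int = 5, warmup: int = 1) -> Dict[str, float]:
    """Time `run` repeatedly and summarize, ignoring the warm-up runs."""
    for _ in range(warmup):
        run()  # not timed: absorbs one-off costs such as file listing
    samples = []
    for _ in range(iterations):
        start = time.perf_counter()  # monotonic, high-resolution clock
        run()
        samples.append(time.perf_counter() - start)
    return {
        "median": statistics.median(samples),
        "min": min(samples),
        "max": max(samples),
        "runs": len(samples),
    }

# Hypothetical Spark usage; the lambda must trigger a full execution itself:
# stats = benchmark(lambda: df.select("id", "status").write.format("noop").mode("overwrite").save())
# print(f"median {stats['median']:.2f}s over {stats['runs']} runs")
```

The median is less sensitive than the mean to a single slow run caused by garbage collection or a busy cluster, and reporting min and max alongside it makes noisy measurements easy to spot.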

Column pruning is a foundational optimization that compounds with other Spark optimizations like predicate pushdown and partition pruning. Master it, verify it’s working, and you’ll consistently write faster, more efficient Spark applications.
