PySpark - Add Multiple Columns to DataFrame

Key Insights

  • Chaining withColumn() calls is readable but creates multiple DataFrame copies, while select() and withColumns() (PySpark 3.3+) perform better by processing columns in a single pass
  • The withColumns() method accepts a dictionary mapping column names to expressions, making it the most efficient and readable option for adding multiple columns in modern PySpark
  • Always examine query plans with explain() to verify that your column additions aren’t triggering unnecessary data shuffles or multiple scans

Introduction

Adding multiple columns to PySpark DataFrames is one of the most common operations in data engineering and machine learning pipelines. Whether you’re performing feature engineering, calculating derived metrics, or transforming raw data into analysis-ready formats, you’ll frequently need to append several new columns based on existing data.

The challenge isn’t just getting the job done—it’s doing it efficiently. PySpark’s distributed nature means that poorly structured transformations can lead to multiple passes over your data, unnecessary shuffles, and degraded performance. Understanding the different approaches for adding columns and their performance implications is crucial for building scalable data pipelines.

Let’s start with a sample DataFrame that we’ll use throughout this article:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, when, expr, sqrt, udf
from pyspark.sql.types import StringType

spark = SparkSession.builder.appName("MultiColumnExample").getOrCreate()

# Sample e-commerce transaction data
data = [
    (1, "laptop", 1200, 2),
    (2, "mouse", 25, 5),
    (3, "keyboard", 75, 3),
    (4, "monitor", 300, 1),
    (5, "headphones", 150, 4)
]

df = spark.createDataFrame(data, ["id", "product", "price", "quantity"])
df.show()

Using withColumn() Method (Basic Approach)

The withColumn() method is the most straightforward way to add columns to a DataFrame. You can chain multiple calls to add several columns sequentially:

# Chaining withColumn() calls
df_with_columns = (df
    .withColumn("total_price", col("price") * col("quantity"))
    .withColumn("tax", col("total_price") * 0.08)
    .withColumn("final_price", col("total_price") + col("tax"))
    .withColumn("price_category", 
        when(col("price") < 50, "budget")
        .when(col("price") < 200, "mid-range")
        .otherwise("premium"))
)

df_with_columns.show()

This approach is readable and intuitive, especially for developers familiar with method chaining. However, it has a significant performance drawback: each withColumn() call introduces a projection internally and returns a new DataFrame. While PySpark’s lazy evaluation means the actual computation doesn’t happen until an action is triggered, deeply chained calls generate large query plans that the Catalyst analyzer must resolve step by step, which can slow plan analysis considerably.

For two or three columns, this overhead is negligible. But when adding dozens of columns in a production pipeline, the inefficiency becomes measurable.

Using select() with Existing and New Columns

A more efficient approach uses select() to specify both existing and new columns in a single operation. The asterisk (*) unpacks all existing columns, and you can append new column expressions:

from pyspark.sql.functions import round as spark_round

df_selected = df.select(
    "*",  # Keep all existing columns
    (col("price") * col("quantity")).alias("total_price"),
    ((col("price") * col("quantity")) * 0.08).alias("tax"),
    ((col("price") * col("quantity")) * 1.08).alias("final_price"),
    when(col("price") < 50, "budget")
        .when(col("price") < 200, "mid-range")
        .otherwise("premium").alias("price_category"),
    spark_round(sqrt(col("price")), 2).alias("price_sqrt")
)

df_selected.show()

This method processes all columns in a single pass through the data. The query optimizer can better understand the full transformation and potentially optimize the execution plan. The downside is that you can’t reference newly created columns within the same select() statement—notice how we had to recalculate total_price for the tax and final_price columns.

For complex transformations where new columns build on each other, you might need to chain multiple select() statements or use intermediate variables.

Using withColumns() Method (PySpark 3.3+)

PySpark 3.3 introduced withColumns() (note the plural), which combines the readability of withColumn() with the performance benefits of single-pass processing. It accepts a dictionary mapping column names to column expressions:

# Available in PySpark 3.3+
column_expressions = {
    "total_price": col("price") * col("quantity"),
    "tax_rate": lit(0.08),
    "discount_eligible": col("quantity") >= 3,
    "price_category": when(col("price") < 50, "budget")
                          .when(col("price") < 200, "mid-range")
                          .otherwise("premium"),
    "price_per_unit": col("price")
}

df_with_columns_dict = df.withColumns(column_expressions)
df_with_columns_dict.show()

# For dependent columns, chain withColumns() calls
df_final = (df
    .withColumns({
        "total_price": col("price") * col("quantity"),
        "tax_rate": lit(0.08)
    })
    .withColumns({
        "tax": col("total_price") * col("tax_rate"),
        "final_price": col("total_price") * (1 + col("tax_rate"))
    })
)

df_final.show()

This is now the recommended approach for adding multiple columns. It’s clean, efficient, and handles the common case elegantly. When columns depend on each other, you can chain withColumns() calls, which still performs better than chaining individual withColumn() calls.

Advanced Techniques

For complex transformations that don’t fit into standard column expressions, you have several advanced options.

User-Defined Functions (UDFs) allow custom Python logic:

# UDF for complex categorization logic
@udf(returnType=StringType())
def categorize_product(product_name, price, quantity):
    if quantity > 3 and price < 100:
        return "bulk_discount_eligible"
    elif price > 500:
        return "premium_item"
    elif product_name.lower() in ["mouse", "keyboard"]:
        return "accessory"
    else:
        return "standard"

df_with_udf = df.withColumn(
    "product_category",
    categorize_product(col("product"), col("price"), col("quantity"))
)

df_with_udf.show()

SQL expressions provide another powerful option, especially when migrating from SQL-based systems:

df_with_sql = df.select(
    "*",
    expr("price * quantity AS total_price"),
    expr("CASE WHEN quantity >= 3 THEN price * 0.9 ELSE price END AS discounted_price"),
    expr("price * quantity * 1.08 AS price_with_tax")
)

df_with_sql.show()

SQL expressions are particularly useful when you have existing SQL logic to port, or when complex CASE statements are more readable than nested when() calls.

Performance Considerations & Best Practices

Understanding the execution plan is critical for optimization. Let’s compare the physical plans:

# Chained withColumn() - multiple transformations
df_chained = (df
    .withColumn("col1", col("price") * 2)
    .withColumn("col2", col("price") * 3)
    .withColumn("col3", col("price") * 4)
)

print("=== Chained withColumn() Plan ===")
df_chained.explain()

# Single select() - single transformation
df_select = df.select(
    "*",
    (col("price") * 2).alias("col1"),
    (col("price") * 3).alias("col2"),
    (col("price") * 4).alias("col3")
)

print("\n=== Single select() Plan ===")
df_select.explain()

# withColumns() approach (PySpark 3.3+)
df_with_cols = df.withColumns({
    "col1": col("price") * 2,
    "col2": col("price") * 3,
    "col3": col("price") * 4
})

print("\n=== withColumns() Plan ===")
df_with_cols.explain()

The execution plans for select() and withColumns() will show a single Project operation, while chained withColumn() calls create nested Project operations. Although the Catalyst optimizer can often collapse these nested projections, it’s better to write efficient code from the start.

Key best practices:

  1. Use withColumns() for PySpark 3.3+: It’s the best balance of readability and performance.

  2. Avoid UDFs when possible: UDFs serialize data to Python, losing Spark’s optimizations. Use built-in functions whenever feasible.

  3. Cache strategically: If you’re adding columns and then performing multiple actions, cache the result:

df_enriched = df.withColumns({
    "total": col("price") * col("quantity"),
    "category": when(col("price") > 100, "high").otherwise("low")
}).cache()

# Multiple actions on the same DataFrame
df_enriched.count()
df_enriched.filter(col("category") == "high").show()

  4. Consider column pruning: Only select the columns you need before adding new ones to reduce data movement:

# Better: select needed columns first
df_optimized = (df.select("id", "price", "quantity")
    .withColumns({"total": col("price") * col("quantity")})
)

Conclusion

Adding multiple columns to PySpark DataFrames is a fundamental operation with several approaches, each suited to different scenarios. For modern PySpark applications (3.3+), withColumns() is the clear winner—it’s readable, efficient, and handles most use cases elegantly. When working with older versions, prefer select() with multiple column expressions over chained withColumn() calls.

Remember that PySpark’s distributed architecture means small inefficiencies multiply across partitions and nodes. Always profile your transformations with explain(), and choose the approach that minimizes passes over your data. When in doubt, benchmark with representative data volumes—what works for 1,000 rows might not scale to 1 billion.
