How to Select Columns in PySpark

Key Insights

  • Use select() with string column names for simple queries, but switch to col() when you need transformations, aliasing, or programmatic column selection
  • Dynamic column selection with list unpacking (*column_list) is essential for building reusable data pipelines that adapt to changing schemas
  • Column pruning through selective projection isn’t just about cleaner code: it lets Spark’s Catalyst optimizer push the projection down to the data source, significantly reducing I/O on large datasets

Introduction

Column selection is the most fundamental DataFrame operation you’ll perform in PySpark. Whether you’re preparing data for a machine learning pipeline, reducing memory footprint before a join, or simply extracting the fields you need for analysis, selecting columns efficiently matters.

Unlike pandas, where column selection is nearly instantaneous on in-memory data, PySpark operates on distributed datasets that may span terabytes across hundreds of nodes. Selecting only the columns you need isn’t just good practice—it’s a performance imperative. Spark’s optimizer can push column projections down to the data source, meaning you avoid reading unnecessary data from disk entirely.

This article covers every practical approach to column selection in PySpark, from basic string-based selection to dynamic patterns and nested struct access.

Basic Column Selection with select()

The select() method is your primary tool for column selection. It accepts column names as strings, Column objects, or a mix of both.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("ColumnSelection").getOrCreate()

# Sample DataFrame
data = [
    ("Alice", 30, "Engineering", 95000),
    ("Bob", 25, "Marketing", 65000),
    ("Charlie", 35, "Engineering", 110000)
]
df = spark.createDataFrame(data, ["name", "age", "department", "salary"])

# String-based selection (simplest approach)
df.select("name", "department").show()

# Using Column objects via df.column_name
df.select(df.name, df.department).show()

# Using col() function
df.select(col("name"), col("department")).show()

All three approaches produce identical results. Use string names when you’re writing quick, exploratory code. Switch to col() when you need to chain transformations.

# String selection breaks when you need transformations
# df.select("name", "salary" * 1.1)  # TypeError - Python tries to multiply the string, not the column

# col() enables inline operations
df.select(col("name"), col("salary") * 1.1).show()

Selecting Columns with col() and df[] Notation

PySpark offers multiple syntaxes for referencing columns. Understanding when to use each prevents subtle bugs.

# Dot notation - clean but limited
df.select(df.name, df.age).show()

# Bracket notation - handles special characters
df.select(df["name"], df["age"]).show()

# col() function - most flexible
df.select(col("name"), col("age")).show()

When to use each approach:

Dot notation (df.column_name) is readable but fails on column names with spaces or special characters, and it resolves to the DataFrame attribute instead of the column when names collide (a column named count, for example, is shadowed by the count() method). Avoid it in production pipelines.

# Dot notation fails here
df_with_spaces = df.toDF("full name", "age", "dept", "annual salary")
# df_with_spaces.select(df_with_spaces.full name)  # SyntaxError

# Bracket notation handles it
df_with_spaces.select(df_with_spaces["full name"]).show()

# col() also works
df_with_spaces.select(col("full name")).show()

The col() function is the most versatile. It works across DataFrames (useful in joins), handles any column name, and integrates cleanly with transformations. Make it your default.

from pyspark.sql.functions import col, upper, round

# col() chains naturally with transformations
df.select(
    col("name"),
    upper(col("department")).alias("dept_upper"),
    round(col("salary") / 12, 2).alias("monthly_salary")
).show()

Selecting Multiple Columns Dynamically

Hardcoding column names works for ad-hoc analysis, but production pipelines need dynamic selection. Python’s list unpacking makes this straightforward.

# Define columns as a list
columns_to_select = ["name", "department", "salary"]

# Unpack with asterisk
df.select(*columns_to_select).show()

# Works with col() objects too
col_objects = [col(c) for c in columns_to_select]
df.select(*col_objects).show()

Selecting columns by pattern:

# Select all columns that start with a prefix
salary_cols = [c for c in df.columns if c.startswith("sal")]
df.select(*salary_cols).show()

# Select columns matching a regex pattern
import re
pattern = re.compile(r"^(name|age)$")
matched_cols = [c for c in df.columns if pattern.match(c)]
df.select(*matched_cols).show()
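Because df.columns is just a Python list of strings, the matching logic can be factored into a plain helper and unit-tested without a SparkSession. The helper name match_columns below is ours for illustration, not a PySpark API:

```python
import re
from typing import List

def match_columns(columns: List[str], pattern: str) -> List[str]:
    """Return the column names that fully match a regex pattern,
    preserving their original schema order."""
    compiled = re.compile(pattern)
    return [c for c in columns if compiled.fullmatch(c)]

# With any DataFrame: df.select(*match_columns(df.columns, r"name|age"))
print(match_columns(["name", "age", "department", "salary"], r"name|age"))
# ['name', 'age']
```

Keeping the filter in a pure function means the same selection logic can be reused across pipelines and tested without spinning up Spark.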

Selecting all columns except specific ones:

# Exclude specific columns
exclude = {"age", "salary"}
remaining_cols = [c for c in df.columns if c not in exclude]
df.select(*remaining_cols).show()

# Alternative using drop() - often cleaner
df.drop("age", "salary").show()

For exclusion patterns, drop() is usually more readable than filtered select(). Use select() when you’re also transforming columns; use drop() for pure exclusion.
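When you do build the exclusion by hand, note that df.columns preserves schema order while a Python set does not, so keep the list comprehension on the left and the set on the right. A small sketch of this as a reusable helper (the name columns_except is ours, not a PySpark API):

```python
from typing import Iterable, List

def columns_except(columns: List[str], exclude: Iterable[str]) -> List[str]:
    """Return column names with the excluded ones removed,
    preserving the original schema order."""
    excluded = set(exclude)  # set gives O(1) membership tests
    return [c for c in columns if c not in excluded]

# With any DataFrame: df.select(*columns_except(df.columns, {"age", "salary"}))
print(columns_except(["name", "age", "department", "salary"], {"age", "salary"}))
# ['name', 'department']
```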

Aliasing and Transforming During Selection

The alias() method renames columns during selection. This is cleaner than selecting then renaming in a separate step.

from pyspark.sql.functions import col, upper, concat, lit

# Basic aliasing
df.select(
    col("name").alias("employee_name"),
    col("department").alias("dept")
).show()

# Combine transformations with aliasing
df.select(
    concat(col("name"), lit(" ("), col("department"), lit(")")).alias("employee_info"),
    (col("salary") * 1.1).alias("salary_with_raise"),
    (col("age") + 5).alias("age_in_5_years")
).show()

For complex transformations, consider withColumn() for readability, but know that chaining many withColumn() calls adds query-planning overhead: each call creates a new projection for the analyzer to resolve, whereas a single select() expresses every transformation in one projection.

# Less efficient - multiple withColumn calls
result = (df
    .withColumn("salary_adjusted", col("salary") * 1.1)
    .withColumn("name_upper", upper(col("name")))
    .select("name_upper", "salary_adjusted"))

# More efficient - single select with transformations
result = df.select(
    upper(col("name")).alias("name_upper"),
    (col("salary") * 1.1).alias("salary_adjusted")
)

Selecting Nested and Struct Columns

Real-world data often contains nested structures—JSON fields, arrays, and maps. PySpark handles these with dot notation and specialized functions.

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType

# Create DataFrame with nested structure
schema = StructType([
    StructField("id", IntegerType()),
    StructField("info", StructType([
        StructField("name", StringType()),
        StructField("email", StringType())
    ])),
    StructField("scores", ArrayType(IntegerType()))
])

nested_data = [
    (1, {"name": "Alice", "email": "alice@example.com"}, [85, 90, 78]),
    (2, {"name": "Bob", "email": "bob@example.com"}, [92, 88, 95])
]

nested_df = spark.createDataFrame(nested_data, schema)

# Access struct fields with dot notation
nested_df.select("id", "info.name", "info.email").show()

# Access array elements by index
nested_df.select("id", col("scores")[0].alias("first_score")).show()

# Use getField() for programmatic access
field_name = "name"
nested_df.select(col("info").getField(field_name).alias("extracted_name")).show()

For deeply nested structures, getField() is essential when field names come from variables:

# Dynamic nested field access in a single projection
fields_to_extract = ["name", "email"]
nested_df.select(
    *[col("info").getField(f).alias(f) for f in fields_to_extract]
).show()

When working with arrays, use explode() to flatten before selecting:

from pyspark.sql.functions import explode

nested_df.select("id", explode("scores").alias("score")).show()

Performance Considerations

Column selection directly impacts query performance through column pruning. When you select specific columns, Spark’s Catalyst optimizer pushes this projection down to the data source.

For Parquet and ORC files: Only the selected columns are read from disk. On wide tables with hundreds of columns, this can reduce I/O by orders of magnitude.

For JDBC sources: The generated SQL query includes only selected columns, reducing network transfer.

# Selecting columns in a separate step
df_full = spark.read.parquet("large_dataset.parquet")
df_subset = df_full.select("col1", "col2")

# Spark optimizes this automatically, but explicit early selection
# makes intent clear and ensures optimization
df_optimized = spark.read.parquet("large_dataset.parquet").select("col1", "col2")

Best practices for large datasets:

  1. Select columns as early as possible in your pipeline
  2. Avoid select("*") unless you genuinely need all columns
  3. When joining DataFrames, select only the columns you need before the join
  4. Use drop() to remove columns you’ve finished using mid-pipeline

# Before a join, trim to necessary columns
df1_slim = df1.select("id", "value1")
df2_slim = df2.select("id", "value2")
joined = df1_slim.join(df2_slim, "id")

Column selection in PySpark is simple in concept but nuanced in practice. Master col() for flexibility, embrace dynamic selection for maintainable pipelines, and always think about what columns you actually need. Your cluster’s memory and your query times will thank you.
