PySpark - Get Column Names as List


Key Insights

  • PySpark offers two primary methods to extract column names: df.columns (returns a list directly) and df.schema.names (extracts from schema metadata), with columns being simpler for most use cases
  • Column name extraction enables powerful dynamic operations like pattern-based filtering, data type-specific selection, and programmatic schema validation in ETL pipelines
  • Combining column name extraction with list comprehensions and filtering techniques allows you to build flexible, maintainable data transformation logic that adapts to schema changes

Introduction

Working with PySpark DataFrames frequently requires programmatic access to column names. Whether you’re building dynamic ETL pipelines, validating schemas across environments, or implementing flexible transformations that adapt to changing data structures, extracting column names as a list is a fundamental operation you’ll use repeatedly.

Unlike working with small datasets where you can manually reference columns, production Spark applications often deal with DataFrames containing dozens or hundreds of columns. You might need to apply transformations to all columns matching a pattern, exclude sensitive fields programmatically, or validate that expected columns exist before running expensive operations. Hard-coding column names creates brittle code that breaks when schemas evolve.

Let’s explore the practical methods for extracting column names from PySpark DataFrames and how to leverage them in real-world scenarios.

First, let’s create a sample DataFrame to work with throughout this article:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

spark = SparkSession.builder.appName("ColumnNamesExample").getOrCreate()

data = [
    (1, "Alice", "Engineering", 95000.0, "alice@example.com"),
    (2, "Bob", "Sales", 75000.0, "bob@example.com"),
    (3, "Charlie", "Engineering", 105000.0, "charlie@example.com"),
    (4, "Diana", "Marketing", 85000.0, "diana@example.com")
]

schema = StructType([
    StructField("user_id", IntegerType(), True),
    StructField("user_name", StringType(), True),
    StructField("department", StringType(), True),
    StructField("salary", DoubleType(), True),
    StructField("email", StringType(), True)
])

df = spark.createDataFrame(data, schema)
df.show()

Using the columns Attribute

The simplest and most direct method for extracting column names is using the columns attribute. This returns a Python list of column names as strings, making it immediately usable with standard Python operations.

# Get column names as a list
column_names = df.columns
print(column_names)
# Output: ['user_id', 'user_name', 'department', 'salary', 'email']

print(type(column_names))
# Output: <class 'list'>

This approach is straightforward and requires no additional method calls. The returned list can be iterated, sliced, or manipulated using standard Python list operations:

# Print each column name
for col in df.columns:
    print(f"Column: {col}")

# Get first three columns
first_three = df.columns[:3]
print(first_three)
# Output: ['user_id', 'user_name', 'department']

# Check if a column exists
if 'salary' in df.columns:
    print("Salary column found")

# Count columns
num_columns = len(df.columns)
print(f"DataFrame has {num_columns} columns")

The columns attribute is the recommended approach for most scenarios due to its simplicity and readability. It’s particularly useful when you need to perform quick checks or transformations on column names without diving into schema metadata.
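
One caveat with the membership check above: 'salary' in df.columns is an exact, case-sensitive string comparison, whereas Spark resolves column references case-insensitively by default (spark.sql.caseSensitive is false unless you change it). The sketch below shows one way to do a case-insensitive lookup. find_column is a hypothetical helper, not part of PySpark, and the plain list stands in for df.columns so the snippet runs without a Spark session.

```python
def find_column(columns, name, case_sensitive=False):
    """Return the actual column name matching `name`, or None if absent.

    Hypothetical helper: `columns` is any list of names, e.g. df.columns.
    """
    if case_sensitive:
        return name if name in columns else None
    wanted = name.lower()
    for col in columns:
        if col.lower() == wanted:
            return col
    return None


# Stand-in for df.columns from the sample DataFrame
columns = ['user_id', 'user_name', 'department', 'salary', 'email']

print('Salary' in columns)             # False (exact match only)
print(find_column(columns, 'Salary'))  # salary
print(find_column(columns, 'bonus'))   # None
```

Returning the name as it actually appears in the DataFrame lets you pass the result straight to df.select or df.drop.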

Using the schema.names Attribute

The alternative approach uses df.schema.names, which reads column names from the DataFrame’s schema object. Like columns, names is an attribute rather than a method, and for top-level columns it returns the same list as df.columns:

# Extract column names from schema
schema_names = df.schema.names
print(schema_names)
# Output: ['user_id', 'user_name', 'department', 'salary', 'email']

# Verify they're identical
print(df.columns == df.schema.names)
# Output: True

While functionally equivalent for extracting column names, schema.names can be more semantically appropriate when you’re already working with schema operations. If you’re inspecting data types, nullable properties, or other schema attributes, using schema.names maintains consistency in your code:

# Working with schema metadata
for field in df.schema.fields:
    print(f"Column: {field.name}, Type: {field.dataType}, Nullable: {field.nullable}")

# In this context, schema.names fits naturally
all_columns = df.schema.names

For most use cases, choose df.columns for its brevity. Use df.schema.names when you’re performing broader schema analysis where accessing the schema object provides additional context.
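
Both df.columns and df.schema.names list only top-level columns, so fields nested inside a struct column do not appear in either. If you need those names as well, one option is to walk the schema yourself. The sketch below works on the dictionary form of a schema, which StructType.jsonValue() returns. The hand-written schema_json stands in for df.schema.jsonValue(), its address struct is a hypothetical column that is not part of the sample DataFrame, and flatten_field_names is an illustrative helper rather than a PySpark function.

```python
def flatten_field_names(schema_json, prefix=""):
    """Return dotted names for all fields, descending into nested structs.

    Assumes `schema_json` has the shape produced by StructType.jsonValue().
    Arrays and maps are treated as leaves and are not descended into.
    """
    names = []
    for field in schema_json["fields"]:
        full_name = f"{prefix}{field['name']}"
        field_type = field["type"]
        if isinstance(field_type, dict) and field_type.get("type") == "struct":
            names.extend(flatten_field_names(field_type, prefix=f"{full_name}."))
        else:
            names.append(full_name)
    return names


# Hand-written stand-in for df.schema.jsonValue(); `address` is hypothetical
schema_json = {
    "type": "struct",
    "fields": [
        {"name": "user_id", "type": "integer", "nullable": True, "metadata": {}},
        {
            "name": "address",
            "type": {
                "type": "struct",
                "fields": [
                    {"name": "city", "type": "string", "nullable": True, "metadata": {}},
                    {"name": "zip", "type": "string", "nullable": True, "metadata": {}},
                ],
            },
            "nullable": True,
            "metadata": {},
        },
    ],
}

print(flatten_field_names(schema_json))
# ['user_id', 'address.city', 'address.zip']
```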

Extracting Specific Column Names with Filtering

The real power of extracting column names emerges when you combine it with filtering techniques. This enables dynamic column selection based on naming patterns or data types.

Filter by Name Pattern

Use list comprehensions to filter columns matching specific patterns:

# Get all columns starting with 'user_'
user_columns = [col for col in df.columns if col.startswith('user_')]
print(user_columns)
# Output: ['user_id', 'user_name']

# Get columns containing specific substring
contact_columns = [col for col in df.columns if 'email' in col or 'phone' in col]
print(contact_columns)
# Output: ['email']

# Exclude specific columns
non_sensitive_columns = [col for col in df.columns if col not in ['salary', 'email']]
print(non_sensitive_columns)
# Output: ['user_id', 'user_name', 'department']
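
When startswith and substring checks are not expressive enough, the same list-comprehension approach works with regular expressions from Python’s re module. This is a minimal sketch: columns_matching is a hypothetical helper, and the plain list stands in for df.columns.

```python
import re


def columns_matching(columns, pattern):
    """Return the column names in which `pattern` matches anywhere."""
    regex = re.compile(pattern)
    return [col for col in columns if regex.search(col)]


# Stand-in for df.columns from the sample DataFrame
columns = ['user_id', 'user_name', 'department', 'salary', 'email']

print(columns_matching(columns, r'^user_'))
# ['user_id', 'user_name']

print(columns_matching(columns, r'(_id|_name)$'))
# ['user_id', 'user_name']

print(columns_matching(columns, r'^(salary|email)$'))
# ['salary', 'email']
```

Because the helper uses search rather than a full match, anchor the pattern with ^ and $ when you want to match the whole name.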

Filter by Data Type

Filtering columns by data type requires accessing the schema, but enables powerful type-based operations:

from pyspark.sql.types import StringType, IntegerType, DoubleType

# Get all string columns
string_columns = [field.name for field in df.schema.fields 
                  if isinstance(field.dataType, StringType)]
print(string_columns)
# Output: ['user_name', 'department', 'email']

# Get all numeric columns
numeric_columns = [field.name for field in df.schema.fields 
                   if isinstance(field.dataType, (IntegerType, DoubleType))]
print(numeric_columns)
# Output: ['user_id', 'salary']

# Get non-string columns
non_string_columns = [field.name for field in df.schema.fields 
                      if not isinstance(field.dataType, StringType)]
print(non_string_columns)
# Output: ['user_id', 'salary']
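
A lighter-weight alternative to isinstance checks is df.dtypes, which returns a list of (column name, type string) tuples such as ('salary', 'double'). The sketch below filters on those strings. The dtypes list is written out by hand to mirror what the sample DataFrame would report, and is_numeric and numeric_column_names are hypothetical helpers, not PySpark functions.

```python
# Simple type strings Spark uses for its fixed-name numeric types
NUMERIC_TYPES = {'tinyint', 'smallint', 'int', 'bigint', 'float', 'double'}


def is_numeric(type_str):
    """True for numeric type strings, including parameterised decimals."""
    return type_str in NUMERIC_TYPES or type_str.startswith('decimal')


def numeric_column_names(dtypes):
    """Return the names of numeric columns from a df.dtypes-style list."""
    return [name for name, type_str in dtypes if is_numeric(type_str)]


# Stand-in for df.dtypes from the sample DataFrame
dtypes = [
    ('user_id', 'int'),
    ('user_name', 'string'),
    ('department', 'string'),
    ('salary', 'double'),
    ('email', 'string'),
]

print(numeric_column_names(dtypes))
# ['user_id', 'salary']

print([name for name, type_str in dtypes if type_str == 'string'])
# ['user_name', 'department', 'email']
```

Exact matching is used for the simple types because a prefix check on 'int' would also match interval types. Decimal columns carry their precision and scale in the string, for example 'decimal(10,2)', so they need the prefix check.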

Practical Applications

Column name extraction becomes invaluable in real-world scenarios where hard-coding column references creates maintenance nightmares.

Dynamic Column Selection

Select multiple columns programmatically without manually listing each one:

# Select all user-related columns dynamically
user_cols = [col for col in df.columns if col.startswith('user_')]
user_df = df.select(user_cols)
user_df.show()

# Select all columns except sensitive ones
sensitive = ['salary', 'email']
public_cols = [col for col in df.columns if col not in sensitive]
public_df = df.select(public_cols)
public_df.show()
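
The same idea extends to reordering: since df.select accepts a list of names, you can compute the order you want from df.columns and pass it in. A minimal sketch, assuming you want a few key columns first and everything else in its original order. reorder_columns is a hypothetical helper, and the plain list stands in for df.columns.

```python
def reorder_columns(columns, first):
    """Return `columns` with the names in `first` moved to the front.

    The remaining columns keep their original relative order.
    """
    missing = [c for c in first if c not in columns]
    if missing:
        raise ValueError(f"Cannot reorder, columns not found: {missing}")
    rest = [c for c in columns if c not in first]
    return list(first) + rest


# Stand-in for df.columns from the sample DataFrame
columns = ['user_id', 'user_name', 'department', 'salary', 'email']

ordered = reorder_columns(columns, ['department', 'user_id'])
print(ordered)
# ['department', 'user_id', 'user_name', 'salary', 'email']
```

With a real DataFrame, the result would be used as df.select(reorder_columns(df.columns, ['department', 'user_id'])).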

Programmatic Column Exclusion

Drop multiple columns matching a pattern without listing each individually:

# Drop all columns containing 'user_'
cols_to_keep = [col for col in df.columns if 'user_' not in col]
cleaned_df = df.select(cols_to_keep)
cleaned_df.show()

# Alternative using drop (when you know what to remove)
user_cols = [col for col in df.columns if col.startswith('user_')]
cleaned_df_alt = df.drop(*user_cols)
cleaned_df_alt.show()

Schema Validation

Validate expected columns exist before running transformations:

def validate_schema(df, required_columns):
    """Validate that all required columns exist in DataFrame"""
    existing_columns = df.columns
    missing_columns = [col for col in required_columns if col not in existing_columns]
    
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    
    return True

# Usage
required = ['user_id', 'user_name', 'department']
try:
    validate_schema(df, required)
    print("Schema validation passed")
except ValueError as e:
    print(f"Schema validation failed: {e}")
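
Checking for missing columns is often only part of the picture. Upstream changes can also add unexpected columns, and joins can leave a DataFrame with duplicate column names, both of which show up in df.columns. The sketch below reports all three cases at once. diff_columns is a hypothetical helper, hire_date is an invented column used only to trigger a mismatch, and the plain lists stand in for df.columns and your expected schema.

```python
def diff_columns(actual, expected):
    """Compare two lists of column names and report the differences."""
    actual_set, expected_set = set(actual), set(expected)
    return {
        # Expected names that the DataFrame does not have
        'missing': [c for c in expected if c not in actual_set],
        # Names the DataFrame has that were not expected
        'unexpected': [c for c in actual if c not in expected_set],
        # Names that appear more than once in the DataFrame
        'duplicated': sorted({c for c in actual if actual.count(c) > 1}),
    }


# Stand-in for df.columns from the sample DataFrame
actual = ['user_id', 'user_name', 'department', 'salary', 'email']
expected = ['user_id', 'user_name', 'department', 'hire_date']

report = diff_columns(actual, expected)
print(report['missing'])     # ['hire_date']
print(report['unexpected'])  # ['salary', 'email']
print(report['duplicated'])  # []
```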

Bulk Column Transformations

Apply transformations to columns matching specific criteria:

from pyspark.sql.functions import col, lower, trim
from pyspark.sql.types import StringType

# Collect the names of all string columns
string_cols = [field.name for field in df.schema.fields 
               if isinstance(field.dataType, StringType)]

# Overwrite each string column with its lowercased, trimmed version
transformed_df = df
for col_name in string_cols:
    transformed_df = transformed_df.withColumn(col_name, trim(lower(col(col_name))))

transformed_df.show()
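
Each withColumn call adds another projection to the query plan, and the PySpark documentation recommends a single select over many withColumn calls in a loop when there are a lot of columns. One way to do that is to build a list of SQL expressions and hand them to df.selectExpr in one go. The sketch below only builds the expression strings, so it runs without a Spark session. build_clean_exprs is a hypothetical helper, the dtypes list stands in for df.dtypes, and the backtick quoting assumes column names that do not themselves contain backticks.

```python
def build_clean_exprs(dtypes):
    """Build one SQL expression per column from a df.dtypes-style list.

    String columns are lowercased and trimmed; others pass through as-is.
    """
    exprs = []
    for name, type_str in dtypes:
        quoted = f"`{name}`"  # assumes no backticks inside the name
        if type_str == 'string':
            exprs.append(f"trim(lower({quoted})) AS {quoted}")
        else:
            exprs.append(quoted)
    return exprs


# Stand-in for df.dtypes from the sample DataFrame
dtypes = [
    ('user_id', 'int'),
    ('user_name', 'string'),
    ('department', 'string'),
    ('salary', 'double'),
    ('email', 'string'),
]

for expr in build_clean_exprs(dtypes):
    print(expr)
# `user_id`
# trim(lower(`user_name`)) AS `user_name`
# trim(lower(`department`)) AS `department`
# `salary`
# trim(lower(`email`)) AS `email`
```

With a real DataFrame, the result would be applied as transformed_df = df.selectExpr(*build_clean_exprs(df.dtypes)), which keeps the original column order and produces a single projection.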

Performance Considerations and Best Practices

Extracting column names is a metadata operation that doesn’t trigger Spark jobs or scan data, so it stays fast no matter how many rows the DataFrame holds. Both df.columns and df.schema.names read from the DataFrame’s schema, and their cost grows only with the number of columns, which is negligible for typical schemas.

However, keep these best practices in mind:

Cache column name lists when reusing them: If you’re referencing column names multiple times in a pipeline, extract them once and store them in a variable rather than repeatedly calling df.columns.

# Good - extract once
cols = df.columns
filtered_cols = [c for c in cols if c.startswith('user_')]
result = df.select(filtered_cols)

# Less efficient - df.columns is looked up again on every iteration
for name in ['user_id', 'user_name']:
    if name in df.columns:
        print(f"{name} found")

Prefer columns over schema.names for simplicity: Unless you’re working extensively with schema metadata, df.columns is more readable and conventional.

Use type-based filtering judiciously: While powerful, filtering by data type requires iterating through schema fields and is more verbose than matching on names. The iteration itself is cheap even with hundreds of columns, so the real question is whether you truly need type-based filtering or whether name-based patterns suffice.

Validate schemas early: In production ETL pipelines, validate expected columns exist immediately after reading data. This fails fast rather than encountering errors deep in transformation logic.

Column name extraction is a simple but essential technique in PySpark development. Master these patterns and you’ll write more flexible, maintainable data pipelines that gracefully handle schema evolution and reduce brittle hard-coded column references.
