PySpark - GroupBy Multiple Columns


Key Insights

  • GroupBy on multiple columns in PySpark forms one group for each unique combination of values across the specified columns, enabling the multi-dimensional aggregations common in real-world analytics.
  • PySpark accepts multi-column grouping keys as separate arguments, as a list, or as Column objects, and lets you combine many aggregation functions in a single agg() call so they are computed together.
  • Performance optimization requires careful consideration of data partitioning and skew, especially when grouping columns have highly uneven value distributions that can bottleneck specific executors.

Introduction & Use Cases

When working with large-scale data processing in PySpark, grouping by multiple columns is a fundamental operation that enables multi-dimensional analysis. Unlike single-column grouping, a multi-column groupBy forms one group for each unique combination of values across all specified columns.

Common real-world scenarios include:

  • Sales analysis: Grouping by region and product category to understand revenue patterns
  • Log aggregation: Combining date and user_id to track user activity over time
  • IoT data processing: Grouping by device_id and timestamp buckets for sensor analytics
  • Customer segmentation: Analyzing behavior by demographic attributes like age_group and location

Let’s set up a sample dataset to demonstrate these concepts:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

spark = SparkSession.builder.appName("MultiColumnGroupBy").getOrCreate()

# Sample sales data
data = [
    ("North", "Electronics", "Laptop", 1200.00, 2),
    ("North", "Electronics", "Phone", 800.00, 5),
    ("North", "Clothing", "Jacket", 150.00, 10),
    ("South", "Electronics", "Laptop", 1200.00, 3),
    ("South", "Electronics", "Tablet", 500.00, 4),
    ("South", "Clothing", "Shoes", 80.00, 15),
    ("East", "Electronics", "Phone", 800.00, 6),
    ("East", "Clothing", "Jacket", 150.00, 8),
    ("West", "Electronics", "Laptop", 1200.00, 1),
    ("West", "Clothing", "Shoes", 80.00, 12)
]

schema = StructType([
    StructField("region", StringType(), True),
    StructField("category", StringType(), True),
    StructField("product", StringType(), True),
    StructField("price", DoubleType(), True),
    StructField("quantity", IntegerType(), True)
])

df = spark.createDataFrame(data, schema)
df.show()

Basic GroupBy with Multiple Columns

The syntax for grouping by multiple columns in PySpark is straightforward. You can pass column names as separate arguments, as a list, or as Column objects. All three forms are functionally equivalent; the variadic form is the most common.

# Method 1: Multiple arguments (preferred)
result1 = df.groupBy("region", "category").count()
result1.show()

# Method 2: List of columns
result2 = df.groupBy(["region", "category"]).count()
result2.show()

# Method 3: Using Column objects
result3 = df.groupBy(df.region, df.category).count()
result3.show()

Each of the three calls returns one row per unique (region, category) combination, along with its row count. The same grouping works with other aggregations, for example total revenue:

# Calculate total revenue by region and category
df_with_revenue = df.withColumn("revenue", F.col("price") * F.col("quantity"))

revenue_by_region_category = df_with_revenue.groupBy("region", "category").sum("revenue")
revenue_by_region_category.show()

This returns a new DataFrame in which each row represents a unique (region, category) combination, with the summed revenue in a column named sum(revenue). The order of columns in groupBy determines the column order of the output, but not which groups are formed or their aggregated values.

Aggregation Functions on Grouped Data

PySpark provides rich aggregation capabilities that go far beyond simple counts. You can apply multiple aggregations simultaneously using the agg() method, which is more flexible than dedicated methods like sum() or avg().

# Single aggregation with dictionary syntax
single_agg = df_with_revenue.groupBy("region", "category").agg({"revenue": "sum"})
single_agg.show()

# Multiple aggregations with explicit functions
multi_agg = df_with_revenue.groupBy("region", "category").agg(
    F.sum("revenue").alias("total_revenue"),
    F.avg("price").alias("avg_price"),
    F.count("*").alias("transaction_count"),
    F.min("quantity").alias("min_quantity"),
    F.max("quantity").alias("max_quantity")
)
multi_agg.show()

Using alias() is crucial for creating readable column names. Without it, PySpark generates names like sum(revenue) which are harder to reference in downstream operations:

# Calculate revenue metrics by region and product
detailed_metrics = df_with_revenue.groupBy("region", "product").agg(
    F.sum("revenue").alias("total_revenue"),
    F.avg("revenue").alias("avg_revenue"),
    F.count("product").alias("sales_count")
)

# Now you can easily reference these columns
top_products = detailed_metrics.filter(F.col("total_revenue") > 2000)
top_products.show()

Advanced Aggregations with Custom Logic

Beyond basic statistical functions, PySpark offers specialized aggregations for collecting values, counting distinct elements, and applying custom logic.

# Collect all products sold in each region-category combination
collected_products = df.groupBy("region", "category").agg(
    F.collect_list("product").alias("products_list"),
    F.collect_set("product").alias("unique_products")
)
collected_products.show(truncate=False)

# Count distinct products per region
distinct_counts = df.groupBy("region").agg(
    F.countDistinct("product").alias("unique_product_count"),
    F.countDistinct("category").alias("unique_category_count")
)
distinct_counts.show()

The difference between collect_list and collect_set is important: collect_list keeps duplicate values, while collect_set returns only unique values. Neither guarantees the order of elements, because rows reach the aggregation from multiple partitions in no fixed order.

For complex business logic, combine multiple aggregation types:

# Comprehensive analysis by region and category
comprehensive = df_with_revenue.groupBy("region", "category").agg(
    F.sum("revenue").alias("total_revenue"),
    F.avg("price").alias("avg_price"),
    F.count("*").alias("num_transactions"),
    F.collect_set("product").alias("products_offered"),
    F.min("quantity").alias("min_order_size"),
    F.max("quantity").alias("max_order_size"),
    F.stddev("price").alias("price_stddev")  # sample stddev; null (Spark 3.x) for single-row groups
)
comprehensive.show(truncate=False)

Sorting and Filtering Grouped Results

After aggregation, you’ll often need to sort results or filter based on aggregated values (similar to SQL’s HAVING clause).

# Sort by total revenue descending
sorted_results = df_with_revenue.groupBy("region", "category").agg(
    F.sum("revenue").alias("total_revenue"),
    F.count("*").alias("transaction_count")
).orderBy(F.desc("total_revenue"))

sorted_results.show()

# Filter aggregated results (HAVING equivalent)
high_revenue = df_with_revenue.groupBy("region", "category").agg(
    F.sum("revenue").alias("total_revenue"),
    F.avg("price").alias("avg_price")
).filter(F.col("total_revenue") > 1500)

high_revenue.show()

# Complex filtering with multiple conditions
filtered = df_with_revenue.groupBy("region", "category").agg(
    F.sum("revenue").alias("total_revenue"),
    F.count("*").alias("transaction_count")
).filter(
    (F.col("total_revenue") > 1000) & (F.col("transaction_count") >= 2)
).orderBy(F.desc("total_revenue"))

filtered.show()

You can also sort by multiple columns, including the grouping columns themselves:

# Sort by region, then by revenue within each region
multi_sort = df_with_revenue.groupBy("region", "category").agg(
    F.sum("revenue").alias("total_revenue")
).orderBy("region", F.desc("total_revenue"))

multi_sort.show()

Performance Considerations

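GroupBy operations can be expensive because they trigger a shuffle: all rows that share a key combination are sent to the same partition. When one or more keys have disproportionately many records, the single task that processes that partition becomes a bottleneck. For built-in aggregates such as sum and count, Spark reduces the impact by partially aggregating within each partition before the shuffle, but aggregates such as collect_list cannot shrink the data this way, because their partial results still contain every value.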

# Check partition distribution before grouping
print(f"Number of partitions: {df.rdd.getNumPartitions()}")

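# Repartitioning by the grouping columns is itself a shuffle and does not fix
# skew (every row for a key still lands in one partition). It mainly helps when
# the repartitioned DataFrame is cached and reused for several aggregations
# on the same keys.
optimized_df = df.repartition("region", "category")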

result = optimized_df.groupBy("region", "category").agg(
    F.sum(F.col("price") * F.col("quantity")).alias("total_revenue")
)

For very large datasets, consider:

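  1. Salting: Add a random salt to skewed keys, aggregate on the salted keys, then combine the partial results per original key
  2. Broadcast joins: When joining small dimension tables before grouping, broadcast them so the large table is not shuffled for the join
  3. Partial aggregation: Built-in DataFrame aggregates already combine values within each partition before the shuffle; in the RDD API, prefer reduceByKey over groupByKey for associative operations
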
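# Example: set the partition count explicitly for a heavy aggregation.
# Hash partitioning on the grouping keys puts each key combination in exactly
# one partition, so useful parallelism is capped by the number of distinct keys.
df_repartitioned = df.repartition(200, "region", "category")

# Aggregated output is usually much smaller than the input, so coalesce
# afterwards to avoid many tiny partitions (and tiny output files)
result = df_repartitioned.groupBy("region", "category").agg(
    F.sum("quantity").alias("total_quantity")
).coalesce(10)  # Reduce partitions after aggregation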

Common Pitfalls & Best Practices

Handling Null Values: Null values in grouping columns create a separate group. Decide whether to filter them out or replace them before grouping.

# Filter out nulls before grouping
cleaned_df = df.filter(F.col("region").isNotNull() & F.col("category").isNotNull())
result = cleaned_df.groupBy("region", "category").count()

# Or replace nulls with a default value
filled_df = df.fillna({"region": "Unknown", "category": "Uncategorized"})
result = filled_df.groupBy("region", "category").count()

Column Naming: Always use alias() for aggregated columns to avoid unwieldy names and potential conflicts.

GroupBy vs. Window Functions: Use groupBy when you need to reduce the dataset. Use window functions when you need to retain all rows while adding aggregated values.

# GroupBy reduces rows
grouped = df.groupBy("region").agg(F.sum("quantity").alias("region_total"))

# Window function retains all rows
from pyspark.sql.window import Window

windowSpec = Window.partitionBy("region")
with_totals = df.withColumn("region_total", F.sum("quantity").over(windowSpec))

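Memory Management: Each array built by collect_list or collect_set must fit in memory on a single executor, so a very large group can cause out-of-memory errors. Filter or pre-aggregate rows before collecting, and use collect_set when duplicates are not needed.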

Multi-column groupBy operations are powerful tools in the PySpark arsenal. Master these patterns, understand the performance implications, and you’ll handle complex analytical queries efficiently across massive datasets.
