PySpark - GroupBy with Aggregation Functions

Key Insights

  • PySpark’s groupBy operations follow the split-apply-combine pattern, distributing aggregation work across cluster nodes for massive performance gains on large datasets compared to pandas.
  • The agg() method is your power tool—it allows multiple aggregations simultaneously with clean syntax, avoiding the performance penalty of multiple separate groupBy operations.
  • Data skew in group keys can cripple performance; understanding partitioning and using techniques like salting or repartitioning before aggregation is critical for production workloads.

Introduction to GroupBy Operations in PySpark

GroupBy operations are fundamental to data analysis, and in PySpark, they’re your primary tool for summarizing distributed datasets. Unlike pandas where groupBy works on a single machine, PySpark distributes the split-apply-combine pattern across your cluster: data is split by group keys, aggregation functions are applied to each partition, and results are combined.
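To build intuition first, the split-apply-combine pattern can be sketched in a few lines of plain Python (a toy single-machine illustration, not how Spark actually executes it):

```python
from collections import defaultdict

rows = [("Electronics", 1200.0), ("Clothing", 45.0), ("Electronics", 25.0)]

# Split: bucket rows by their group key
groups = defaultdict(list)
for category, amount in rows:
    groups[category].append(amount)

# Apply + combine: aggregate each bucket's values
totals = {category: sum(amounts) for category, amounts in groups.items()}
print(totals)  # {'Electronics': 1225.0, 'Clothing': 45.0}
```

Spark performs the same three steps, except the "split" is a shuffle across the cluster and the "apply" runs in parallel on each partition.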

This distributed approach means you can aggregate terabytes of data efficiently, but it also requires understanding how PySpark optimizes these operations. Let’s start with a realistic dataset we’ll use throughout this article:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

spark = SparkSession.builder.appName("GroupByAggregations").getOrCreate()

# Sample sales data
data = [
    ("Electronics", "Laptop", "North", 1200.00, "2024-01-15"),
    ("Electronics", "Mouse", "North", 25.00, "2024-01-16"),
    ("Electronics", "Laptop", "South", 1200.00, "2024-01-16"),
    ("Clothing", "Shirt", "North", 45.00, "2024-01-15"),
    ("Clothing", "Pants", "South", 65.00, "2024-01-17"),
    ("Electronics", "Keyboard", "East", 75.00, "2024-01-17"),
    ("Clothing", "Shirt", "East", 45.00, "2024-01-18"),
    ("Electronics", "Mouse", "North", 25.00, "2024-01-18"),
    ("Clothing", "Pants", "North", 65.00, "2024-01-19"),
]

schema = StructType([
    StructField("category", StringType(), True),
    StructField("product", StringType(), True),
    StructField("region", StringType(), True),
    StructField("amount", DoubleType(), True),
    StructField("date", StringType(), True)
])

df = spark.createDataFrame(data, schema)
df.show()

Basic GroupBy with Single Aggregation

The simplest groupBy operation applies a single aggregation function. PySpark provides built-in functions like count(), sum(), avg(), min(), and max() that work directly on grouped data.

# Count transactions per category
category_counts = df.groupBy("category").count()
category_counts.show()
# Output:
# +-----------+-----+
# |   category|count|
# +-----------+-----+
# |   Clothing|    4|
# |Electronics|    5|
# +-----------+-----+

# Calculate average amount per region
region_avg = df.groupBy("region").avg("amount")
region_avg.show()
# Output:
# +------+-----------+
# |region|avg(amount)|
# +------+-----------+
# | South|      632.5|
# | North|      272.0|
# |  East|       60.0|
# +------+-----------+

# Find maximum sale amount per category
category_max = df.groupBy("category").max("amount")
category_max.show()

# Total revenue per region
region_sum = df.groupBy("region").sum("amount")
region_sum.show()

Each of these operations creates a new DataFrame with the group key column(s) and the aggregated result. Notice the default column naming like avg(amount)—we’ll fix that shortly.

Multiple Aggregations on Grouped Data

Running separate groupBy operations is inefficient because each one triggers a separate shuffle operation. Instead, use the agg() method to apply multiple aggregations in a single pass:

# Multiple aggregations using agg() with dictionary syntax
category_stats = df.groupBy("category").agg({
    "amount": "sum",
    "product": "count",
    "amount": "avg"  # Note: duplicate keys overwrite in dict syntax
})
category_stats.show()

# Better approach: use Column expressions with alias for clarity
category_stats = df.groupBy("category").agg(
    F.sum("amount").alias("total_revenue"),
    F.avg("amount").alias("avg_sale"),
    F.count("product").alias("num_transactions"),
    F.min("amount").alias("min_sale"),
    F.max("amount").alias("max_sale")
)
category_stats.show()
# Output:
# +-----------+-------------+--------+----------------+--------+--------+
# |   category|total_revenue|avg_sale|num_transactions|min_sale|max_sale|
# +-----------+-------------+--------+----------------+--------+--------+
# |   Clothing|        220.0|    55.0|               4|    45.0|    65.0|
# |Electronics|       2525.0|   505.0|               5|    25.0|  1200.0|
# +-----------+-------------+--------+----------------+--------+--------+

# Multi-level aggregations with different columns
detailed_stats = df.groupBy("category", "region").agg(
    F.sum("amount").alias("revenue"),
    F.countDistinct("product").alias("unique_products")
)
detailed_stats.show()

The Column expression approach with alias() is clearer and more flexible. Always import pyspark.sql.functions as F for these operations.

Advanced Aggregation Functions

Beyond basic statistics, PySpark offers specialized aggregation functions for complex scenarios:

# Collect all products sold per category (as list)
category_products = df.groupBy("category").agg(
    F.collect_list("product").alias("all_products"),
    F.collect_set("product").alias("unique_products")
)
category_products.show(truncate=False)
# Output shows arrays of products

# Count distinct products efficiently
distinct_counts = df.groupBy("category").agg(
    F.countDistinct("product").alias("exact_distinct"),
    F.approx_count_distinct("product").alias("approx_distinct")
)
distinct_counts.show()

# Get first and last transaction dates per category
# Note: first() and last() are non-deterministic on unordered data
from pyspark.sql import Window

date_ranges = df.groupBy("category").agg(
    F.min("date").alias("first_sale"),
    F.max("date").alias("last_sale"),
    F.first("product").alias("first_product")  # arbitrary without order
)
date_ranges.show()

# More useful: combine with window functions for ordered first/last
windowSpec = Window.partitionBy("category").orderBy("date")
df_with_rank = df.withColumn("rank", F.row_number().over(windowSpec))
first_products = df_with_rank.filter(F.col("rank") == 1).select("category", "product")
first_products.show()

Use collect_list() and collect_set() carefully—they bring all of a group’s values into memory on a single executor, which can cause out-of-memory errors for large groups. approx_count_distinct() uses the HyperLogLog algorithm and is much faster on large datasets, at the cost of a small, configurable accuracy trade-off.

GroupBy with Multiple Columns and Filtering

Real-world analysis often requires grouping by multiple dimensions and filtering aggregated results:

# Group by multiple columns
multi_group = df.groupBy("category", "region").agg(
    F.sum("amount").alias("total_sales"),
    F.count("*").alias("transaction_count")
)
multi_group.show()

# Filter aggregated results (equivalent to SQL HAVING clause)
high_value_segments = df.groupBy("category", "region").agg(
    F.sum("amount").alias("total_sales")
).filter(F.col("total_sales") > 100)
high_value_segments.show()

# Combine groupBy with window functions for percentage calculations
category_totals = df.groupBy("category").agg(
    F.sum("amount").alias("category_total")
)

# Join back to get percentage of category total per region
region_category = df.groupBy("category", "region").agg(
    F.sum("amount").alias("region_sales")
)

result = region_category.join(category_totals, "category")
result = result.withColumn(
    "percentage", 
    F.round((F.col("region_sales") / F.col("category_total")) * 100, 2)
)
result.show()

The filter after aggregation is crucial—applying it before groupBy would filter individual rows, not aggregated groups.

Performance Considerations and Best Practices

GroupBy operations trigger a shuffle, redistributing data across executors. Poor partitioning can create severe performance bottlenecks:

# Check current partitioning
print(f"Number of partitions: {df.rdd.getNumPartitions()}")

# Repartition before groupBy if you have skewed data
df_repartitioned = df.repartition(10, "category")
result = df_repartitioned.groupBy("category").agg(
    F.sum("amount").alias("total")
)

# Compare execution plans
df.groupBy("category").sum("amount").explain()
df_repartitioned.groupBy("category").sum("amount").explain()

# For highly skewed data, use salting technique
# Add random salt to distribute skewed keys
df_salted = df.withColumn("salt", (F.rand() * 10).cast("int"))
df_salted = df_salted.withColumn("salted_key", F.concat(F.col("category"), F.lit("_"), F.col("salt")))

# Group by salted key, then aggregate again
intermediate = df_salted.groupBy("salted_key", "category").agg(
    F.sum("amount").alias("partial_sum")
)
final_result = intermediate.groupBy("category").agg(
    F.sum("partial_sum").alias("total_amount")
)
final_result.show()

Use explain() to understand query plans. Look for “Exchange” operations indicating shuffles. If one partition is much larger than others (data skew), salting distributes the load more evenly.

Common Pitfalls and Troubleshooting

Avoid these common mistakes when working with groupBy aggregations:

# PITFALL 1: Nulls in aggregations
data_with_nulls = [
    ("A", 100), ("A", None), ("B", 200), ("B", 300)
]
df_nulls = spark.createDataFrame(data_with_nulls, ["group", "value"])

# sum() and count(col) both ignore nulls; count("*") counts every row
df_nulls.groupBy("group").agg(
    F.sum("value").alias("total"),
    F.count("value").alias("count_non_null"),
    F.count("*").alias("count_all")
).show()

# PITFALL 2: Collecting large grouped data
# DON'T DO THIS on large datasets:
# large_groups = df.groupBy("category").agg(F.collect_list("product"))
# large_groups.collect()  # Can cause OOM errors

# Instead, write to storage or use take()
# large_groups.write.parquet("output/grouped_data")

# PITFALL 3: Incorrect aggregation syntax
# WRONG: df.groupBy("category").agg("sum(amount)")  # raises an error—agg() needs Column expressions or a dict
# RIGHT:
df.groupBy("category").agg(F.sum("amount")).show()

# PITFALL 4: Unaliased aggregations get awkward default names
# This runs fine, but produces columns named sum(amount) and
# avg(amount), which are clumsy to reference downstream:
# df.groupBy("category").agg(F.sum("amount"), F.avg("amount"))

# Use alias:
df.groupBy("category").agg(
    F.sum("amount").alias("total"),
    F.avg("amount").alias("average")
).show()

For custom aggregation logic beyond built-in functions, consider User Defined Aggregate Functions (UDAFs), though they’re more complex and often slower than built-in functions. In most cases, you can achieve your goal by combining built-in aggregations creatively.

Master these groupBy patterns and you’ll handle most real-world data aggregation scenarios efficiently. The key is understanding how PySpark distributes work and choosing the right aggregation strategy for your data characteristics and cluster resources.
