How to GroupBy and Aggregate in PySpark

Key Insights

  • PySpark’s groupBy() combined with agg() lets you perform multiple aggregations in a single pass over your data, which is critical for performance in distributed computing
  • Always use alias() to rename aggregated columns—the default names like sum(sales) will cause headaches in downstream processing and joins
  • Watch for data skew when grouping; a single key with millions of records can bottleneck your entire job while other executors sit idle

Introduction

GroupBy and aggregation operations form the backbone of data analysis in PySpark. Whether you’re calculating total sales by region, finding average response times by service, or counting events by user, you’ll reach for groupBy() constantly.

Unlike pandas, where operations happen in memory on a single machine, PySpark distributes your data across a cluster. This means grouping operations trigger a shuffle—data with the same key must move to the same partition for aggregation. Understanding this helps you write efficient code and avoid performance pitfalls.

This article covers everything from basic syntax to performance optimization, with practical examples you can adapt for your own pipelines.

Basic GroupBy Syntax

The groupBy() method returns a GroupedData object, which you then call aggregation methods on. The simplest pattern chains a single aggregation directly.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("GroupByDemo").getOrCreate()

# Sample sales data
data = [
    ("North", "Electronics", 1200),
    ("North", "Clothing", 800),
    ("South", "Electronics", 1500),
    ("South", "Electronics", 900),
    ("North", "Electronics", 1100),
    ("West", "Clothing", 600),
    ("South", "Clothing", 750),
    ("West", "Electronics", 1300),
]

df = spark.createDataFrame(data, ["region", "category", "amount"])

# Count transactions by region
transactions_by_region = df.groupBy("region").count()
transactions_by_region.show()

Output:

+------+-----+
|region|count|
+------+-----+
| North|    3|
| South|    3|
|  West|    2|
+------+-----+

You can swap count() for other built-in aggregations:

# Sum of sales by region
df.groupBy("region").sum("amount").show()

# Average sale amount by region
df.groupBy("region").avg("amount").show()

# Min and max in separate calls
df.groupBy("region").min("amount").show()
df.groupBy("region").max("amount").show()

This works fine for quick exploration, but real-world analysis usually requires multiple aggregations at once.

Multiple Aggregations with agg()

The agg() method accepts multiple aggregation expressions, letting you compute several metrics in one operation. Import functions from pyspark.sql.functions and pass them as arguments.

# Note: importing these shadows Python's built-in sum, min, and max;
# many codebases use "from pyspark.sql import functions as F" instead
from pyspark.sql.functions import sum, avg, max, min, count

# Employee data with department info
employee_data = [
    ("Engineering", "Alice", 95000, 15000),
    ("Engineering", "Bob", 105000, 20000),
    ("Engineering", "Carol", 88000, 12000),
    ("Sales", "Dave", 72000, 25000),
    ("Sales", "Eve", 68000, 18000),
    ("Marketing", "Frank", 78000, 10000),
    ("Marketing", "Grace", 82000, 14000),
    ("Marketing", "Henry", 75000, 11000),
]

employees = spark.createDataFrame(
    employee_data, 
    ["department", "name", "salary", "bonus"]
)

# Multiple aggregations in one pass
department_stats = employees.groupBy("department").agg(
    sum("salary"),
    avg("salary"),
    max("bonus"),
    count("name")
)
department_stats.show()

Output:

+-----------+-----------+------------------+----------+-----------+
| department|sum(salary)|       avg(salary)|max(bonus)|count(name)|
+-----------+-----------+------------------+----------+-----------+
|Engineering|     288000|           96000.0|     20000|          3|
|      Sales|     140000|           70000.0|     25000|          2|
|  Marketing|     235000|78333.333333333333|     14000|          3|
+-----------+-----------+------------------+----------+-----------+

This approach is more efficient than running separate groupBy operations because Spark only shuffles the data once.

Grouping by Multiple Columns

Pass multiple column names to groupBy() for hierarchical aggregations. This is essential for time-series analysis, cohort breakdowns, and dimensional reporting.

# Sales data with temporal dimension
sales_data = [
    (2023, "Q1", "Electronics", 45000),
    (2023, "Q1", "Clothing", 32000),
    (2023, "Q2", "Electronics", 52000),
    (2023, "Q2", "Clothing", 28000),
    (2024, "Q1", "Electronics", 48000),
    (2024, "Q1", "Clothing", 35000),
    (2024, "Q2", "Electronics", 61000),
    (2024, "Q2", "Clothing", 41000),
]

sales = spark.createDataFrame(
    sales_data, 
    ["year", "quarter", "category", "revenue"]
)

# Group by year and category to see trends
yearly_category = sales.groupBy("year", "category").agg(
    sum("revenue").alias("total_revenue"),
    avg("revenue").alias("avg_quarterly_revenue")
)
yearly_category.orderBy("year", "category").show()

Output:

+----+-----------+-------------+---------------------+
|year|   category|total_revenue|avg_quarterly_revenue|
+----+-----------+-------------+---------------------+
|2023|   Clothing|        60000|              30000.0|
|2023|Electronics|        97000|              48500.0|
|2024|   Clothing|        76000|              38000.0|
|2024|Electronics|       109000|              54500.0|
+----+-----------+-------------+---------------------+

The order of columns in groupBy() doesn’t affect the result, but it does affect readability. Put the most significant dimension first.

Common Aggregation Functions

PySpark provides a rich set of aggregation functions. Here’s a practical reference with examples:

from pyspark.sql.functions import (
    sum, avg, min, max, count, 
    countDistinct, collect_list, collect_set,
    first, last, stddev, variance
)

# User activity data
activity_data = [
    ("user1", "login", 10),
    ("user1", "purchase", 50),
    ("user1", "login", 5),
    ("user2", "login", 8),
    ("user2", "view", 2),
    ("user2", "purchase", 75),
    ("user2", "login", 12),
    ("user3", "login", 15),
    ("user3", "view", 3),
]

activity = spark.createDataFrame(
    activity_data, 
    ["user_id", "action", "duration"]
)

# Demonstrate various aggregation functions
user_summary = activity.groupBy("user_id").agg(
    count("action").alias("total_actions"),
    countDistinct("action").alias("unique_actions"),
    sum("duration").alias("total_duration"),
    avg("duration").alias("avg_duration"),
    min("duration").alias("min_duration"),
    max("duration").alias("max_duration"),
    collect_list("action").alias("action_list"),
    collect_set("action").alias("action_set")
)
user_summary.show(truncate=False)

Output:

+-------+-------------+--------------+--------------+------------+------------+------------+------------------------------+----------------------+
|user_id|total_actions|unique_actions|total_duration|avg_duration|min_duration|max_duration|action_list                   |action_set            |
+-------+-------------+--------------+--------------+------------+------------+------------+------------------------------+----------------------+
|user1  |3            |2             |65            |21.666...   |5           |50          |[login, purchase, login]      |[login, purchase]     |
|user2  |4            |3             |97            |24.25       |2           |75          |[login, view, purchase, login]|[login, view, purchase]|
|user3  |2            |2             |18            |9.0         |3           |15          |[login, view]                 |[login, view]         |
+-------+-------------+--------------+--------------+------------+------------+------------+------------------------------+----------------------+

A few notes on these functions:

  • collect_list keeps duplicates while collect_set deduplicates; neither guarantees element order in a distributed aggregation
  • Both collect_* functions materialize every value of a group in executor memory, so use them cautiously on groups with many rows
  • countDistinct is expensive on large datasets—consider approx_count_distinct for estimates

Renaming and Aliasing Aggregated Columns

Default column names from aggregations are ugly and problematic. Names like sum(salary) contain parentheses that break column references in subsequent operations. Always use alias().

from pyspark.sql.functions import sum, avg, count, round

# Create a clean, report-ready summary
report = employees.groupBy("department").agg(
    count("*").alias("headcount"),
    round(avg("salary"), 2).alias("avg_salary"),
    sum("salary").alias("total_payroll"),
    round(avg("bonus"), 2).alias("avg_bonus"),
    sum("bonus").alias("total_bonus_pool")
)

# Now you can reference columns cleanly
report = report.withColumn(
    "bonus_to_salary_ratio",
    round(report.total_bonus_pool / report.total_payroll * 100, 2)
)

report.show()

Output:

+-----------+---------+----------+-------------+---------+----------------+---------------------+
| department|headcount|avg_salary|total_payroll|avg_bonus|total_bonus_pool|bonus_to_salary_ratio|
+-----------+---------+----------+-------------+---------+----------------+---------------------+
|Engineering|        3|   96000.0|       288000| 15666.67|           47000|                16.32|
|      Sales|        2|   70000.0|       140000|  21500.0|           43000|                30.71|
|  Marketing|        3|  78333.33|       235000| 11666.67|           35000|                14.89|
+-----------+---------+----------+-------------+---------+----------------+---------------------+

This pattern—aggregating with aliases, then adding derived columns—produces clean DataFrames ready for export or joining with other datasets.

Performance Considerations

GroupBy operations trigger shuffles, which are expensive. Here’s how to minimize pain:

Partition your data strategically. If you frequently filter or group by a column, consider partitioning your data by that column when writing to storage. Partitioned files let Spark prune irrelevant data at read time, and bucketing on the grouping column (bucketBy with saveAsTable) can let subsequent groupBy operations on that column skip the shuffle.

Watch for skew. If one key has vastly more records than others, that partition becomes a bottleneck. Techniques include salting keys (adding random suffixes, then aggregating twice) or filtering out problematic keys for separate processing.

Use approximate functions for estimates. When exact counts aren’t necessary, approx_count_distinct provides massive speedups:

from pyspark.sql.functions import countDistinct, approx_count_distinct
import time

# Simulate a larger dataset
large_data = [(f"user_{i % 10000}", f"action_{i % 100}") 
              for i in range(1000000)]
large_df = spark.createDataFrame(large_data, ["user_id", "action"])

# Compare exact vs approximate distinct counts
start = time.time()
exact = large_df.groupBy("action").agg(
    countDistinct("user_id").alias("exact_unique_users")
)
exact.collect()
exact_time = time.time() - start

start = time.time()
approx = large_df.groupBy("action").agg(
    approx_count_distinct("user_id", 0.05).alias("approx_unique_users")
)
approx.collect()
approx_time = time.time() - start

print(f"Exact: {exact_time:.2f}s, Approximate: {approx_time:.2f}s")

The second parameter to approx_count_distinct is the relative standard deviation—0.05 means roughly 5% error, which is acceptable for most dashboards and monitoring use cases.

Filter before grouping. Reduce data volume before the shuffle whenever possible. A where() clause before groupBy() can dramatically reduce the amount of data that needs to move across the network.

GroupBy and aggregation are fundamental operations you’ll use in virtually every PySpark pipeline. Master the agg() pattern with proper aliasing, understand the shuffle implications, and you’ll write cleaner, faster distributed data processing code.
