How to GroupBy in PySpark
Key Insights
- PySpark’s groupBy() combined with agg() gives you fine-grained control over multiple aggregations in a single pass, which is essential for efficient distributed processing.
- Always prefer built-in aggregate functions from pyspark.sql.functions over UDFs—they’re optimized for Spark’s execution engine and avoid Python serialization overhead.
- Data skew is the silent killer of GroupBy performance; monitor partition sizes and consider salting techniques when one key dominates your dataset.
Introduction
GroupBy operations are the backbone of data analysis in PySpark. Whether you’re calculating sales totals by region, counting user events by session, or computing average response times by service, you’ll reach for groupBy() constantly.
Unlike pandas, where groupby operations happen in memory on a single machine, PySpark distributes the work across a cluster. This means your data gets shuffled between nodes—an expensive operation that you need to understand to write performant code.
This article covers everything from basic syntax to advanced patterns, with a focus on what actually matters in production: getting correct results efficiently.
Basic GroupBy Syntax
The groupBy() method returns a GroupedData object, which you then chain with an aggregation method. Let’s start with a sample dataset:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.appName("GroupByDemo").getOrCreate()
# Sample sales data
data = [
    ("Electronics", "North", 1200, 5),
    ("Electronics", "South", 800, 3),
    ("Electronics", "North", 1500, 7),
    ("Clothing", "North", 400, 10),
    ("Clothing", "South", 600, 15),
    ("Clothing", "South", 350, 8),
    ("Furniture", "North", 2000, 2),
    ("Furniture", "South", 1800, 3),
]
df = spark.createDataFrame(data, ["category", "region", "revenue", "quantity"])
Basic aggregations are straightforward:
# Count rows per category
df.groupBy("category").count().show()
# Sum revenue per category
df.groupBy("category").sum("revenue").show()
# Average quantity per category
df.groupBy("category").avg("quantity").show()
These single-aggregation methods work fine for quick exploration, but production code typically needs more. The column names they generate (sum(revenue), avg(quantity)) are awkward and inconsistent across Spark versions.
Multiple Aggregations with agg()
The agg() method is what you’ll use in real applications. It accepts multiple aggregate expressions and lets you control output column names:
from pyspark.sql import functions as F
category_stats = df.groupBy("category").agg(
    F.sum("revenue").alias("total_revenue"),
    F.avg("revenue").alias("avg_revenue"),
    F.max("revenue").alias("max_sale"),
    F.min("revenue").alias("min_sale"),
    F.count("*").alias("transaction_count"),
    F.sum("quantity").alias("total_units_sold")
)
category_stats.show()
Output:
+-----------+-------------+------------------+--------+--------+-----------------+----------------+
|   category|total_revenue|       avg_revenue|max_sale|min_sale|transaction_count|total_units_sold|
+-----------+-------------+------------------+--------+--------+-----------------+----------------+
|Electronics|         3500|1166.6666666666667|    1500|     800|                3|              15|
|   Clothing|         1350|             450.0|     600|     350|                3|              33|
|  Furniture|         3800|            1900.0|    2000|    1800|                2|               5|
+-----------+-------------+------------------+--------+--------+-----------------+----------------+
The alias() method is non-negotiable in production code. Without it, you’re left with column names like sum(revenue) that break downstream queries and make your code fragile.
You can also use a dictionary syntax, though I find it less readable:
# Alternative dictionary syntax
df.groupBy("category").agg({
    "revenue": "sum",
    "quantity": "avg"
}).show()
Stick with the explicit F.function().alias() pattern—it’s clearer and more flexible.
Grouping by Multiple Columns
Real-world analysis rarely groups by a single column. Pass multiple columns to groupBy() to create composite groups:
regional_category_stats = df.groupBy("category", "region").agg(
    F.sum("revenue").alias("total_revenue"),
    F.sum("quantity").alias("total_quantity"),
    F.count("*").alias("num_transactions"),
    F.round(F.avg("revenue"), 2).alias("avg_transaction_value")
)
regional_category_stats.orderBy("category", "region").show()
Output:
+-----------+------+-------------+--------------+----------------+---------------------+
|   category|region|total_revenue|total_quantity|num_transactions|avg_transaction_value|
+-----------+------+-------------+--------------+----------------+---------------------+
|   Clothing| North|          400|            10|               1|                400.0|
|   Clothing| South|          950|            23|               2|                475.0|
|Electronics| North|         2700|            12|               2|               1350.0|
|Electronics| South|          800|             3|               1|                800.0|
|  Furniture| North|         2000|             2|               1|               2000.0|
|  Furniture| South|         1800|             3|               1|               1800.0|
+-----------+------+-------------+--------------+----------------+---------------------+
The order of columns in groupBy() doesn’t affect results, but it does affect the output column order. Choose an order that makes the output readable.
Advanced Aggregations
Beyond basic sums and averages, PySpark offers powerful aggregation functions for complex analysis.
Collecting Values into Lists or Sets
Sometimes you need to aggregate values into collections rather than scalars:
# Collect all revenue values per category
df.groupBy("category").agg(
    F.collect_list("revenue").alias("all_revenues"),
    F.collect_set("region").alias("unique_regions")
).show(truncate=False)
Output:
+-----------+------------------+--------------+
|category   |all_revenues      |unique_regions|
+-----------+------------------+--------------+
|Electronics|[1200, 800, 1500] |[South, North]|
|Clothing   |[400, 600, 350]   |[South, North]|
|Furniture  |[2000, 1800]      |[South, North]|
+-----------+------------------+--------------+
Use collect_list() when order or duplicates matter; use collect_set() for unique values. Be cautious with large groups: each group's entire collection must fit in memory in a single row on one executor, which can cause out-of-memory failures.
Conditional Aggregations
The when() function enables conditional logic within aggregations:
# Conditional aggregations
df.groupBy("category").agg(
    F.sum(F.when(F.col("region") == "North", F.col("revenue")).otherwise(0)).alias("north_revenue"),
    F.sum(F.when(F.col("region") == "South", F.col("revenue")).otherwise(0)).alias("south_revenue"),
    F.count(F.when(F.col("revenue") > 1000, True)).alias("high_value_transactions"),
    F.sum(F.when(F.col("quantity") >= 5, F.col("quantity")).otherwise(0)).alias("bulk_order_units")
).show()
Output:
+-----------+-------------+-------------+-----------------------+----------------+
|   category|north_revenue|south_revenue|high_value_transactions|bulk_order_units|
+-----------+-------------+-------------+-----------------------+----------------+
|Electronics|         2700|          800|                      2|              12|
|   Clothing|          400|          950|                      0|              33|
|  Furniture|         2000|         1800|                      2|               0|
+-----------+-------------+-------------+-----------------------+----------------+
This pattern is incredibly useful for creating pivot-table-like summaries without actually pivoting the data.
First and Last Values
When you need representative values from each group:
# Pick a representative value from each group
df.orderBy("revenue").groupBy("category").agg(
    F.first("revenue").alias("lowest_revenue"),
    F.last("revenue").alias("highest_revenue"),
    F.first("region").alias("region_of_lowest")
).show()
Note: first() and last() depend on the order in which rows reach the aggregation, and sorting beforehand (as above) does not guarantee that order survives the shuffle. Treat these results as non-deterministic unless you enforce ordering by other means.
Performance Considerations
GroupBy operations trigger shuffles—data movement across the cluster. This is often the most expensive part of your job.
Partition Your Data Strategically
If you’re performing multiple groupBy operations on the same key, repartition first:
# Repartition once, benefit multiple times
df_partitioned = df.repartition(200, "category")
# These operations will have less shuffle overhead
stats1 = df_partitioned.groupBy("category").agg(F.sum("revenue"))
stats2 = df_partitioned.groupBy("category").agg(F.avg("quantity"))
The repartition() call causes one shuffle, but subsequent groupBy operations on the same key can leverage the existing partitioning.
Handle Data Skew
Data skew—when one key has dramatically more records than others—kills performance. One partition does all the work while others sit idle.
# Detect skew: check group sizes
df.groupBy("category").count().orderBy(F.desc("count")).show()
# Salting technique for severely skewed keys
# Add a random salt to break up large groups
num_salts = 10
df_salted = df.withColumn("salted_category",
    F.concat(F.col("category"), F.lit("_"), (F.rand() * num_salts).cast("int"))
)
# Aggregate with salt, then aggregate again to combine
partial_agg = df_salted.groupBy("salted_category").agg(
    F.sum("revenue").alias("partial_revenue")
)

# Remove the salt and run the final aggregation
final_agg = partial_agg.withColumn("category",
    F.split("salted_category", "_")[0]
).groupBy("category").agg(
    F.sum("partial_revenue").alias("total_revenue")
)
This two-stage aggregation distributes work more evenly at the cost of an extra shuffle.
Avoid groupByKey in RDD Operations
If you’re working with RDDs (you probably shouldn’t be), prefer reduceByKey over groupByKey. The former combines values locally before shuffling, dramatically reducing data transfer. With DataFrames, Spark’s Catalyst optimizer handles this automatically.
Conclusion
PySpark’s groupBy operations are straightforward once you understand the patterns:
- Use groupBy().agg() with explicit aliases for production code
- Chain multiple aggregations in a single agg() call for efficiency
- Leverage when() for conditional aggregations
- Watch for data skew and consider salting for problematic keys
- Repartition strategically when performing multiple operations on the same grouping key
The official PySpark SQL Functions documentation covers every available aggregate function. Bookmark it—you’ll reference it constantly.