PySpark - GroupBy and Max/Min

Key Insights

  • PySpark’s groupBy() combined with agg() enables efficient max/min calculations across billions of rows through distributed processing, but understanding when to use window functions versus standard aggregations determines whether you retain row-level detail or collapse to group summaries.
  • The most common mistake is using groupBy() when you actually need window functions—if you want to keep all rows while marking which ones contain the max/min values per group, windows are the correct tool.
  • Performance hinges on partition strategy: well-partitioned data on your grouping key can make max/min operations 10x faster by minimizing shuffle operations across your cluster.

Basic GroupBy with Max and Min

PySpark’s groupBy() operation collapses rows into groups and applies aggregate functions like max() and min(). This is your bread-and-butter operation for answering questions like “What’s the highest sale in each region?” or “What’s the minimum temperature recorded per city?”

The syntax is straightforward. You call groupBy() on your DataFrame with the column(s) you want to group by, then chain aggregate functions:

from pyspark.sql import SparkSession
from pyspark.sql.functions import max, min

spark = SparkSession.builder.appName("MaxMinExample").getOrCreate()

# Sample sales data
data = [
    ("North", "2024-01", 15000),
    ("North", "2024-02", 18000),
    ("South", "2024-01", 12000),
    ("South", "2024-02", 14000),
    ("East", "2024-01", 20000),
    ("East", "2024-02", 19000)
]

df = spark.createDataFrame(data, ["region", "month", "sales"])

# Find max and min sales per region
result = df.groupBy("region").agg(
    max("sales").alias("max_sales"),
    min("sales").alias("min_sales")
)

result.show()

Output:

+------+---------+---------+
|region|max_sales|min_sales|
+------+---------+---------+
| North|    18000|    15000|
| South|    14000|    12000|
|  East|    20000|    19000|
+------+---------+---------+

Notice the alias() method. Always alias your aggregated columns. The default names like max(sales) are awkward to reference downstream because the parentheses require backtick-quoting.

Multiple Column Aggregations

Real-world scenarios rarely involve a single metric. You’ll typically want max, min, average, and count all at once. The agg() method accepts multiple aggregation functions:

from pyspark.sql.functions import max, min, avg, count

# Employee data
employee_data = [
    ("Engineering", "Alice", 95000, 28),
    ("Engineering", "Bob", 87000, 32),
    ("Engineering", "Charlie", 105000, 29),
    ("Sales", "David", 75000, 35),
    ("Sales", "Eve", 82000, 27),
    ("Sales", "Frank", 68000, 41)
]

emp_df = spark.createDataFrame(
    employee_data, 
    ["department", "name", "salary", "age"]
)

# Multiple aggregations at once
dept_stats = emp_df.groupBy("department").agg(
    max("salary").alias("max_salary"),
    min("salary").alias("min_salary"),
    avg("salary").alias("avg_salary"),
    min("age").alias("youngest"),
    max("age").alias("oldest"),
    count("*").alias("employee_count")
)

dept_stats.show()

This produces a comprehensive summary per department in a single pass through the data. In distributed systems, minimizing passes is critical for performance.

Using Window Functions for Max/Min

Here’s where most developers hit a wall. groupBy() collapses your data—you get one row per group. But what if you need to mark which rows contain the max value while keeping all rows?

Window functions solve this. They perform calculations across a set of rows related to the current row without collapsing the dataset:

from pyspark.sql.window import Window
from pyspark.sql.functions import max, col

# Product sales data
product_data = [
    ("Electronics", "Laptop", 1200),
    ("Electronics", "Phone", 800),
    ("Electronics", "Tablet", 600),
    ("Furniture", "Desk", 450),
    ("Furniture", "Chair", 200),
    ("Furniture", "Lamp", 75)
]

product_df = spark.createDataFrame(
    product_data, 
    ["category", "product", "price"]
)

# Define window partitioned by category
window_spec = Window.partitionBy("category")

# Add column showing max price per category
result = product_df.withColumn(
    "max_price_in_category",
    max("price").over(window_spec)
)

result.show()

Output:

+-----------+-------+-----+---------------------+
|   category|product|price|max_price_in_category|
+-----------+-------+-----+---------------------+
|Electronics| Laptop| 1200|                 1200|
|Electronics|  Phone|  800|                 1200|
|Electronics| Tablet|  600|                 1200|
|  Furniture|   Desk|  450|                  450|
|  Furniture|  Chair|  200|                  450|
|  Furniture|   Lamp|   75|                  450|
+-----------+-------+-----+---------------------+

Every row now knows the maximum price in its category. You can then filter to find which products are the most expensive in their category:

top_products = result.filter(col("price") == col("max_price_in_category"))
top_products.show()

Advanced Patterns: Filtering After Aggregation

A common requirement: “Give me the complete record of the top-performing product in each category.” You can’t do this with groupBy() alone because it only returns the max value, not the entire row.

The solution combines window functions with filtering:

from pyspark.sql.functions import row_number

# Extended product data with more attributes
extended_data = [
    ("Electronics", "Laptop", 1200, "Brand A", 4.5),
    ("Electronics", "Phone", 800, "Brand B", 4.7),
    ("Electronics", "Tablet", 1200, "Brand C", 4.3),  # Tie for max
    ("Furniture", "Desk", 450, "Brand D", 4.1),
    ("Furniture", "Chair", 200, "Brand E", 4.6)
]

extended_df = spark.createDataFrame(
    extended_data,
    ["category", "product", "price", "brand", "rating"]
)

# Rank products within each category by price
window_spec = Window.partitionBy("category").orderBy(col("price").desc())

ranked = extended_df.withColumn("rank", row_number().over(window_spec))

# Get top product per category
top_per_category = ranked.filter(col("rank") == 1)
top_per_category.show()

This approach handles ties by taking the first row according to the ordering. If you want all tied records, use rank() or dense_rank() instead of row_number():

from pyspark.sql.functions import rank

# Get all products tied for highest price
window_spec = Window.partitionBy("category").orderBy(col("price").desc())
ranked = extended_df.withColumn("rank", rank().over(window_spec))
all_top_products = ranked.filter(col("rank") == 1)
all_top_products.show()

Performance Considerations

Max/min operations trigger shuffles when your data isn’t already partitioned by your grouping key. Shuffles are expensive—they move data across the network between executors.

Check your execution plan:

df.groupBy("region").agg(max("sales")).explain()

Look for “Exchange” operations in the plan. These indicate shuffles. If you’re repeatedly grouping by the same key, repartition once:

# Repartition by the grouping key
optimized_df = df.repartition("region")

# Now groupBy operations on region are more efficient
result1 = optimized_df.groupBy("region").agg(max("sales"))
result2 = optimized_df.groupBy("region").agg(min("sales"))

For very large datasets, consider caching after expensive operations:

cached_df = df.groupBy("region").agg(max("sales")).cache()
cached_df.count()  # Trigger caching
# Subsequent operations on cached_df are faster

When you only need max/min without additional aggregations, and your data is already in RDD form as key-value pairs, reduceByKey() can be efficient because it combines values locally on each partition before shuffling:

# RDD approach for a simple max operation
# Note: pyspark.sql.functions.max shadows the builtin max() here,
# so compare explicitly inside the lambda
rdd = df.rdd.map(lambda row: (row.region, row.sales))
max_by_region = rdd.reduceByKey(lambda a, b: a if a > b else b)

However, DataFrame operations are generally better optimized thanks to the Catalyst optimizer, so stick with DataFrames unless you have a specific reason to use RDDs.

Common Pitfalls and Best Practices

Null handling: max() and min() ignore null values, but a group whose values are all null returns null:

from pyspark.sql.functions import coalesce, lit

# Handle potential nulls in results
safe_result = df.groupBy("region").agg(
    coalesce(max("sales"), lit(0)).alias("max_sales"),
    coalesce(min("sales"), lit(0)).alias("min_sales")
)

Ties for max/min values: As shown earlier, decide whether you want one arbitrary row or all tied rows. Use row_number() for one, rank() for all.

Column name conflicts: When grouping by a column and selecting it again, be explicit:

# Avoid ambiguous column references
result = df.groupBy("region").agg(
    max("sales").alias("max_sales")
).select("region", "max_sales")  # Explicit selection

Memory issues with large groups: If individual groups are massive, consider approximate algorithms:

from pyspark.sql.functions import approx_count_distinct

# For very large groups, approximate aggregations are faster
df.groupBy("region").agg(
    approx_count_distinct("customer_id", 0.05).alias("approx_customers")
)

Testing with small data: Always test aggregation logic on a small sample before running on production data:

# Test on sample
sample_df = df.sample(0.01)  # 1% sample
sample_result = sample_df.groupBy("region").agg(max("sales"))
sample_result.show()

PySpark’s max/min operations are deceptively simple but incredibly powerful when you understand the distinction between collapsing aggregations and window functions. Master both patterns, pay attention to your execution plans, and handle edge cases explicitly. Your data pipelines will be faster, more correct, and easier to maintain.
