PySpark - GroupBy and Average (Mean)

Key Insights

  • PySpark’s avg() and mean() functions are interchangeable aliases that compute averages across grouped data in distributed environments, with agg() providing the most flexible syntax for complex aggregations
  • Grouping by multiple columns creates hierarchical aggregations that enable multi-dimensional analysis, while proper column aliasing and null handling prevent common data quality issues in production pipelines
  • Performance optimization through strategic repartitioning and caching can substantially reduce execution time on large datasets, especially when the same groupBy operation feeds multiple downstream transformations

Introduction to GroupBy Operations in PySpark

GroupBy operations form the backbone of data aggregation in PySpark, enabling you to collapse millions or billions of rows into meaningful summaries. Unlike pandas, where groupby runs in memory on a single machine, PySpark distributes these calculations across cluster nodes, making it possible to analyze datasets that would overwhelm a traditional setup.

The average (mean) calculation is one of the most common aggregations you’ll perform. Whether you’re analyzing customer spending patterns, sensor readings, employee performance metrics, or financial data, computing averages across categorical groups reveals trends that raw data obscures.

Let’s start with a realistic dataset to work with throughout this article:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("GroupByAverage").getOrCreate()

# Sample sales data
data = [
    ("North", "Electronics", "2024-01", 1200.50),
    ("North", "Electronics", "2024-01", 980.00),
    ("South", "Electronics", "2024-01", 1450.75),
    ("North", "Clothing", "2024-01", 320.00),
    ("South", "Clothing", "2024-01", 290.50),
    ("North", "Electronics", "2024-02", 1100.00),
    ("South", "Electronics", "2024-02", 1380.25),
    ("North", "Clothing", "2024-02", 410.00),
    ("South", "Clothing", "2024-02", 385.75),
]

df = spark.createDataFrame(data, ["region", "category", "month", "sales"])
df.show()

This creates a DataFrame representing sales transactions across regions and product categories—a scenario you’ll encounter constantly in business analytics.

Basic GroupBy with Average Syntax

PySpark offers two identical functions for calculating averages: avg() and mean(). They’re complete aliases with no performance or functional differences. Use whichever reads more naturally to you, but stay consistent within your codebase.

The most straightforward approach groups by a single column and calculates the average of another:

# Calculate average sales by region
avg_by_region = df.groupBy("region").avg("sales")
avg_by_region.show()

# Output:
# +------+----------+
# |region|avg(sales)|
# +------+----------+
# | South|  876.8125|
# | North|     802.1|
# +------+----------+

The avg() method applied directly after groupBy() works for simple cases, but you’ll quickly outgrow it. The agg() method provides far more flexibility:

# Using agg() with avg() - preferred approach
avg_by_region = df.groupBy("region").agg(F.avg("sales"))
avg_by_region.show()

# Using mean() instead of avg()
avg_by_region = df.groupBy("region").agg(F.mean("sales"))
avg_by_region.show()

The agg() approach becomes essential when you need multiple aggregations or want to rename columns inline. Always import functions as F and use F.avg() rather than the string-based syntax: it gives better IDE autocompletion and lets you compose expressions, such as wrapping the average in F.round() or adding an alias inline.

Grouping by Multiple Columns

Real-world analysis rarely involves single-dimension grouping. You’ll typically need to slice your data across multiple categorical variables simultaneously:

# Group by region AND category
avg_by_region_category = df.groupBy("region", "category").agg(
    F.avg("sales").alias("avg_sales")
)
avg_by_region_category.show()

# Output:
# +------+-----------+---------+
# |region|   category|avg_sales|
# +------+-----------+---------+
# | South|   Clothing|  338.125|
# | North|Electronics|   1093.5|
# | South|Electronics|   1415.5|
# | North|   Clothing|    365.0|
# +------+-----------+---------+

Notice the alias() method—it renames the output column from the default avg(sales) to something cleaner. Always alias your aggregation columns in production code.

When you need averages for multiple numeric columns, pass them all to agg():

# Add a quantity column for demonstration
data_multi = [
    ("North", "Electronics", 1200.50, 5),
    ("North", "Electronics", 980.00, 3),
    ("South", "Electronics", 1450.75, 7),
    ("North", "Clothing", 320.00, 12),
]

df_multi = spark.createDataFrame(
    data_multi, ["region", "category", "sales", "quantity"]
)

# Calculate averages for multiple columns
multi_avg = df_multi.groupBy("region", "category").agg(
    F.avg("sales").alias("avg_sales"),
    F.avg("quantity").alias("avg_quantity")
)
multi_avg.show()

For even more concise syntax when averaging many columns, use dictionary notation:

# Dictionary syntax for multiple aggregations
multi_avg = df_multi.groupBy("region", "category").agg(
    {"sales": "avg", "quantity": "avg"}
)
multi_avg.show()

The dictionary approach is cleaner for uniform operations across columns, but you lose the ability to alias individually. Choose based on your specific needs.

Advanced Aggregations and Transformations

Production pipelines rarely involve just averages. You’ll combine multiple aggregation functions to build comprehensive summaries:

# Combine multiple aggregation functions
comprehensive_stats = df.groupBy("region", "category").agg(
    F.avg("sales").alias("avg_sales"),
    F.sum("sales").alias("total_sales"),
    F.count("sales").alias("transaction_count"),
    F.min("sales").alias("min_sales"),
    F.max("sales").alias("max_sales")
)
comprehensive_stats.show()

Averages with many decimal places clutter reports. Round them appropriately:

# Round averages to 2 decimal places
rounded_avg = df.groupBy("region", "category").agg(
    F.round(F.avg("sales"), 2).alias("avg_sales")
)
rounded_avg.show()

Null values require explicit handling. PySpark's avg() skips nulls automatically, computing the mean of the non-null values only. If you need different behavior:

# Data with nulls
data_nulls = [
    ("North", "Electronics", 1200.50),
    ("North", "Electronics", None),
    ("North", "Electronics", 980.00),
]

df_nulls = spark.createDataFrame(data_nulls, ["region", "category", "sales"])

# Default behavior - nulls ignored
avg_ignore_null = df_nulls.groupBy("region").agg(F.avg("sales"))
avg_ignore_null.show()  # Average of 1200.50 and 980.00

# Drop rows with nulls before grouping
avg_no_nulls = df_nulls.na.drop().groupBy("region").agg(F.avg("sales"))

# Fill nulls with a value before grouping
avg_filled = df_nulls.na.fill(0).groupBy("region").agg(F.avg("sales"))

Choose your null-handling strategy based on business logic. Ignoring nulls makes sense for optional fields; filling with zero works for truly missing transactions; dropping nulls is appropriate when null indicates invalid data.

Performance Considerations

GroupBy operations trigger shuffles—data movement across the cluster—which can become expensive on large datasets. Understanding and optimizing this is critical for production performance.

Check the execution plan to see what Spark is doing:

# View the physical execution plan
df.groupBy("region").agg(F.avg("sales")).explain()

The output shows an Exchange (shuffle) step. Repartitioning by the grouping column ahead of time pays off mainly when several aggregations reuse the same key, because the subsequent groupBy can reuse the existing partitioning instead of shuffling again:

# Repartition by the groupBy column before aggregating
optimized = df.repartition("region").groupBy("region").agg(
    F.avg("sales").alias("avg_sales")
)

This co-locates rows with the same key on the same partition. Note that repartition() is itself a shuffle, so it only helps when that cost is amortized across multiple downstream aggregations.

For repeated aggregations on the same DataFrame, caching prevents redundant computation:

# Cache when you'll use the same groupBy multiple times
df.cache()

# First aggregation - computes and caches
avg1 = df.groupBy("region").agg(F.avg("sales"))
avg1.show()

# Second aggregation - uses cached data
avg2 = df.groupBy("category").agg(F.avg("sales"))
avg2.show()

# Don't forget to unpersist when done
df.unpersist()

Caching stores the DataFrame in memory across the cluster, spilling to disk if it doesn't fit. Only cache when you'll reuse the data multiple times; caching itself has overhead.

Practical Use Case Example

Let’s build a complete analytical workflow that calculates regional performance metrics for a retail chain:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Initialize Spark
spark = SparkSession.builder.appName("RetailAnalysis").getOrCreate()

# Realistic sales dataset
sales_data = [
    ("North", "Electronics", "Q1", 2500.00, 15),
    ("North", "Electronics", "Q1", 3200.00, 20),
    ("North", "Clothing", "Q1", 450.00, 8),
    ("South", "Electronics", "Q1", 2800.00, 18),
    ("South", "Clothing", "Q1", 520.00, 12),
    ("North", "Electronics", "Q2", 3100.00, 22),
    ("North", "Clothing", "Q2", 480.00, 9),
    ("South", "Electronics", "Q2", 2950.00, 19),
    ("South", "Clothing", "Q2", 510.00, 11),
]

df = spark.createDataFrame(
    sales_data, ["region", "category", "quarter", "revenue", "units_sold"]
)

# Comprehensive analysis: average revenue and units by region and category
analysis = df.groupBy("region", "category").agg(
    F.round(F.avg("revenue"), 2).alias("avg_revenue"),
    F.round(F.avg("units_sold"), 2).alias("avg_units"),
    F.sum("revenue").alias("total_revenue"),
    F.count("*").alias("transaction_count")
).orderBy("region", F.desc("avg_revenue"))

print("Regional Performance Analysis:")
analysis.show()

# Calculate average transaction value (revenue per unit)
detailed_analysis = df.groupBy("region", "category").agg(
    F.round(F.avg("revenue"), 2).alias("avg_revenue"),
    F.round(F.avg("units_sold"), 2).alias("avg_units"),
    F.round(F.avg(F.col("revenue") / F.col("units_sold")), 2).alias("avg_price_per_unit")
)

print("Detailed Pricing Analysis:")
detailed_analysis.show()

This workflow demonstrates production-ready practices: proper aliasing, rounding for readability, combining multiple aggregations, sorting results logically, and calculating derived metrics like per-unit pricing.

GroupBy with average is deceptively simple syntactically but offers immense analytical power. Master the patterns shown here—multi-column grouping, comprehensive aggregations, null handling, and performance optimization—and you’ll handle 90% of real-world aggregation scenarios efficiently. The key is understanding when to use each approach and how Spark executes your operations under the hood.
