PySpark - Aggregate Functions (sum, avg, max, min, count)


Key Insights

  • PySpark aggregate functions operate on distributed DataFrames and scale to billions of rows, unlike pandas, which loads everything into a single machine's memory. Use PySpark when your data exceeds single-machine capacity or when you need fault-tolerant processing across a cluster.
  • The groupBy() method returns a GroupedData object that needs an aggregation (such as agg()) to produce a DataFrame, and you can compute multiple aggregates in a single agg() call to avoid redundant data shuffles across the cluster.
  • Understanding how PySpark shuffles data during groupBy operations, and how partitioning affects that shuffle, is critical for writing efficient aggregation queries that don't bottleneck your pipeline.

Introduction to PySpark Aggregations

Aggregate functions are fundamental operations in any data processing framework. In PySpark, these functions enable you to summarize, analyze, and extract insights from massive datasets distributed across cluster nodes. Unlike pandas, which processes data in-memory on a single machine, PySpark aggregations leverage distributed computing to handle datasets that would crash a traditional Python environment.

You should reach for PySpark aggregations when your data outgrows a single machine (a common rule of thumb is datasets beyond roughly 10GB, though the real threshold depends on available memory), when you need horizontal scalability, or when your data pipeline already runs on Spark infrastructure. The syntax mirrors SQL aggregations, making it familiar to data engineers and analysts.

Let’s create a sample dataset to work with throughout this article:

from pyspark.sql import SparkSession
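# Note: importing sum, max, and min by name shadows Python's built-ins for the rest of the module;
# many codebases use `from pyspark.sql import functions as F` and write F.sum(...) instead
from pyspark.sql.functions import col, sum, avg, max, min, count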

spark = SparkSession.builder.appName("AggregationExample").getOrCreate()

# Create sample sales data
data = [
    ("2024-01-01", "Electronics", "Laptop", 1200.00, 2, "North"),
    ("2024-01-01", "Electronics", "Mouse", 25.00, 5, "North"),
    ("2024-01-02", "Clothing", "Shirt", 45.00, 3, "South"),
    ("2024-01-02", "Electronics", "Keyboard", 75.00, 4, "East"),
    ("2024-01-03", "Clothing", "Pants", 65.00, 2, "North"),
    ("2024-01-03", "Electronics", "Monitor", 300.00, 1, "West"),
    ("2024-01-04", "Clothing", "Jacket", 120.00, 1, "South"),
    ("2024-01-04", "Electronics", "Laptop", 1200.00, 1, "East"),
]

columns = ["date", "category", "product", "price", "quantity", "region"]
df = spark.createDataFrame(data, columns)
df.show()

Basic Aggregate Functions

PySpark ships many aggregate functions; this article focuses on the five core ones that mirror standard SQL: sum, avg, max, min, and count. They work on entire DataFrames or on grouped subsets of data.

The simplest aggregation applies to the entire DataFrame using the agg() method:

# Single aggregation
total_revenue = df.agg(sum(col("price") * col("quantity"))).collect()[0][0]
print(f"Total Revenue: ${total_revenue}")

# Multiple aggregations in one operation
summary = df.agg(
    sum(col("price") * col("quantity")).alias("total_revenue"),
    avg("price").alias("avg_price"),
    max("quantity").alias("max_quantity"),
    min("price").alias("min_price"),
    count("*").alias("transaction_count")
)
summary.show()

The alias() method is critical for readability—without it, PySpark generates column names like sum((price * quantity)), which becomes unwieldy in production code. Always alias your aggregated columns.

You can also pass aggregate expressions directly to select():

# Alternative syntax: aggregate expressions inside select()
df.select(
    sum(col("price")).alias("total_price"),
    count("*").alias("row_count")
).show()

The agg() approach is more flexible and composable, especially when combining multiple aggregations or working with grouped data.

GroupBy Aggregations

The real power of aggregations emerges when combined with groupBy(). This operation groups rows that share the same values in the specified columns, then applies the aggregate functions to each group.

# Single column groupBy
category_sales = df.groupBy("category").agg(
    sum(col("price") * col("quantity")).alias("total_sales"),
    avg("price").alias("avg_price"),
    count("*").alias("num_transactions")
)
category_sales.show()

The output shows aggregated metrics per category (Electronics vs. Clothing in our example). Note that groupBy() returns a GroupedData object, not a DataFrame: you must apply an aggregation such as agg() to get a DataFrame back. That DataFrame is still lazy; the job only runs when an action such as show() or collect() is called.

Multi-column grouping enables more granular analysis:

# Multi-column groupBy
regional_category_sales = df.groupBy("region", "category").agg(
    sum(col("price") * col("quantity")).alias("revenue"),
    count("product").alias("product_count")
).orderBy("region", "category")
regional_category_sales.show()

You can compute many metrics in a single agg() call:

# Complex aggregation with multiple metrics
detailed_analysis = df.groupBy("category").agg(
    sum(col("price") * col("quantity")).alias("revenue"),
    avg(col("price") * col("quantity")).alias("avg_transaction_value"),
    max("price").alias("highest_price"),
    min("price").alias("lowest_price"),
    count("*").alias("transaction_count"),
    sum("quantity").alias("total_units_sold")
)
detailed_analysis.show()

This single agg() call computes six different metrics in one pass over the data (and one shuffle), which is far more efficient than running six separate aggregations.

Advanced Aggregation Techniques

Beyond basic groupBy operations, PySpark offers sophisticated aggregation patterns for complex analytical queries.

Pivoting turns the distinct values of one column into separate output columns, which is useful for reporting:

# Pivot table: regions as rows, categories as columns
pivot_sales = df.groupBy("region").pivot("category").agg(
    sum(col("price") * col("quantity"))
)
pivot_sales.show()

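This creates a matrix view where each region shows sales broken down by category in separate columns. Region/category combinations with no sales (for example, West has no Clothing rows) appear as null.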

Rollup creates hierarchical aggregations with subtotals:

# Rollup provides hierarchical aggregations
rollup_result = df.rollup("region", "category").agg(
    sum(col("price") * col("quantity")).alias("total_sales")
).orderBy("region", "category")
rollup_result.show()

Rollup generates aggregations at multiple levels: a grand total across all data, totals per region, and totals per region-category combination. In the higher-level rows the rolled-up columns are null, so if your grouping columns can also contain genuine nulls, those rows become ambiguous.

Window functions enable running aggregations without collapsing rows:

from pyspark.sql.window import Window

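# Running total using window functions.
# With orderBy() and no explicit frame, Spark uses a RANGE frame from the start of the
# partition to the current row, so rows that share the same date get the same running total.
window_spec = Window.partitionBy("category").orderBy("date")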

df_with_running_total = df.withColumn(
    "running_revenue",
    sum(col("price") * col("quantity")).over(window_spec)
)
df_with_running_total.select("date", "category", "product", "running_revenue").show()

Window functions maintain row-level detail while computing aggregates, perfect for time-series analysis and ranking operations.

Performance Considerations & Best Practices

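Aggregation performance in PySpark depends heavily on data distribution and shuffle operations. When a groupBy() aggregation runs, Spark first computes partial aggregates within each partition (for functions like sum, avg, max, min, and count), then shuffles those partial results across executors so that rows with the same key end up together for the final merge. The shuffle is usually the most expensive step.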

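Partition awareness matters when you reuse the same grouping key:

# Repartition by the grouping key (repartition() is itself a full shuffle)
df_repartitioned = df.repartition("category")
optimized_agg = df_repartitioned.groupBy("category").agg(
    sum(col("price") * col("quantity")).alias("total_sales")
)

Repartitioning by your groupBy key does not remove the shuffle; it moves it earlier. Because the data is then already hash-partitioned by that key, Spark can skip the Exchange it would otherwise add for the aggregation. That trade pays off when you reuse (and typically cache) the repartitioned DataFrame for several operations on the same key. For a single aggregation it is usually a loss, since a plain groupBy() shuffles only the small partial aggregates rather than every full row.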

Examine execution plans to understand performance:

# Check physical plan
category_sales.explain(mode="formatted")

The explain() output reveals shuffle operations, partition counts, and optimization opportunities. Look for “Exchange” operations—these indicate data shuffles that may bottleneck performance.

Caching strategies matter for iterative aggregations:

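# Cache the base DataFrame when several aggregations will read it
df.cache()

# cache() is lazy: the first action materializes the cache, later actions reuse it
agg1 = df.groupBy("category").agg(sum("quantity").alias("total_quantity"))
agg2 = df.groupBy("region").agg(avg("price").alias("avg_price"))
agg1.show()
agg2.show()

# Release the cached data when you are done
df.unpersist()

Cache the source DataFrame, not the aggregated result, when you'll compute multiple different aggregations on the same base data. Both cache() and the aggregations are lazy, so an action such as show(), collect(), or count() has to run between cache() and unpersist(); otherwise nothing is ever cached.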

SQL vs DataFrame API: Both are compiled by the same Catalyst optimizer, so equivalent queries generally produce the same execution plan. Choose based on team preference:

# DataFrame API
df_api_result = df.groupBy("category").agg(sum("quantity").alias("total"))

# SQL approach (register temp view first)
df.createOrReplaceTempView("sales")
sql_result = spark.sql("""
    SELECT category, SUM(quantity) as total
    FROM sales
    GROUP BY category
""")

Common Pitfalls and Solutions

Null handling can produce unexpected results:

from pyspark.sql.functions import when, coalesce, lit

# Create data with nulls
data_with_nulls = [
    ("A", 100, None),
    ("A", 200, 50),
    ("B", None, 30),
    ("B", 150, None),
]
df_nulls = spark.createDataFrame(data_with_nulls, ["category", "value1", "value2"])

# Aggregations ignore nulls by default
df_nulls.groupBy("category").agg(
    sum("value1").alias("sum_value1"),  # Nulls ignored
    count("value2").alias("count_value2")  # Counts only non-null
).show()

# Handle nulls explicitly
df_nulls.groupBy("category").agg(
    sum(coalesce(col("value1"), lit(0))).alias("sum_with_zeros"),
    count(when(col("value2").isNull(), 1)).alias("null_count")
).show()

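Explicit type casting keeps aggregations predictable, for example when numeric columns arrive as strings from CSV files. The price column in our sample data is already a double, so the cast below is a no-op here, but the pattern is the same: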

from pyspark.sql.types import DoubleType

# Ensure numeric types before aggregation
df_typed = df.withColumn("price", col("price").cast(DoubleType()))
df_typed.groupBy("category").agg(avg("price")).show()

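Exact distinct counts over high-cardinality columns are memory- and shuffle-intensive, because Spark has to track every distinct value. When an estimate is acceptable, an approximate aggregation is much cheaper:

# Exact vs. approximate distinct counts
from pyspark.sql.functions import approx_count_distinct, countDistinct

# Exact distinct count (tracks every distinct value: expensive at high cardinality)
exact = df.agg(countDistinct("product").alias("distinct_products"))

# Approximate distinct count (HyperLogLog++; rsd=0.02 caps the relative standard deviation at about 2%)
approx = df.agg(approx_count_distinct("product", rsd=0.02).alias("approx_distinct_products"))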

PySpark aggregate functions provide the foundation for scalable data analysis. Master these patterns, understand their performance characteristics, and you’ll build efficient data pipelines that scale from gigabytes to petabytes. The key is thinking in terms of distributed operations rather than in-memory processing—let Spark handle the complexity while you focus on the logic.
