Spark SQL - Aggregate Functions


Key Insights

  • Spark SQL aggregate functions process groups of rows to compute single values, with built-in functions like sum(), avg(), count(), and collect_list() covering most analytical needs while custom UDAFs handle specialized logic.
  • Window functions extend aggregation capabilities by computing values across row partitions without collapsing groups, enabling running totals, rankings, and moving averages within the same result set.
  • Performance optimization requires understanding partition skew, using appropriate aggregation strategies (partial vs. full), and leveraging Catalyst optimizer hints for complex multi-stage aggregations.

Basic Aggregation Operations

Spark SQL provides comprehensive aggregate functions that operate on grouped data. The fundamental pattern involves grouping rows by one or more columns and applying aggregate functions to compute summary statistics.

from pyspark.sql import SparkSession
from pyspark.sql.functions import sum, avg, count, min, max, stddev

spark = SparkSession.builder.appName("aggregations").getOrCreate()

# Sample sales data
data = [
    ("2024-01-15", "Electronics", "Laptop", 1200, 2),
    ("2024-01-15", "Electronics", "Mouse", 25, 10),
    ("2024-01-16", "Clothing", "Shirt", 45, 5),
    ("2024-01-16", "Electronics", "Keyboard", 75, 3),
    ("2024-01-17", "Clothing", "Pants", 60, 4)
]

df = spark.createDataFrame(data, ["date", "category", "product", "price", "quantity"])

# Basic aggregation by category
category_stats = df.groupBy("category").agg(
    sum("price").alias("total_price"),
    avg("price").alias("avg_price"),
    count("product").alias("product_count"),
    min("quantity").alias("min_quantity"),
    max("quantity").alias("max_quantity")
)

category_stats.show()

The groupBy() method creates a grouped DataFrame, and agg() applies multiple aggregate functions simultaneously. This approach is more efficient than running separate aggregations.
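
The same multi-metric aggregation can also be expressed in Spark SQL; a minimal sketch, assuming a temporary view name of "sales" (illustrative) registered from the DataFrame above:

# Register the DataFrame as a temporary view and aggregate with SQL
df.createOrReplaceTempView("sales")

category_stats_sql = spark.sql("""
    SELECT category,
           SUM(price)     AS total_price,
           AVG(price)     AS avg_price,
           COUNT(product) AS product_count,
           MIN(quantity)  AS min_quantity,
           MAX(quantity)  AS max_quantity
    FROM sales
    GROUP BY category
""")

category_stats_sql.show()

Both forms compile to the same logical plan under Catalyst, so the choice between the DataFrame API and SQL is largely stylistic.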

Advanced Aggregation Functions

Beyond basic statistics, Spark SQL offers specialized aggregate functions for complex analytical tasks.

from pyspark.sql.functions import (
    collect_list, collect_set, countDistinct, 
    approx_count_distinct, first, last
)

# Collecting values into arrays
product_aggregations = df.groupBy("category").agg(
    collect_list("product").alias("all_products"),
    collect_set("product").alias("unique_products"),
    countDistinct("product").alias("distinct_count"),
    approx_count_distinct("product", 0.05).alias("approx_distinct"),
    first("product").alias("first_product"),
    last("product").alias("last_product")
)

product_aggregations.show(truncate=False)

# Calculate revenue and aggregate
from pyspark.sql.functions import expr

df_with_revenue = df.withColumn("revenue", expr("price * quantity"))

daily_revenue = df_with_revenue.groupBy("date").agg(
    sum("revenue").alias("total_revenue"),
    avg("revenue").alias("avg_revenue"),
    stddev("revenue").alias("revenue_stddev")
)

daily_revenue.orderBy("date").show()

The collect_list() and collect_set() functions are particularly useful for creating array columns from grouped data. Use approx_count_distinct() for large datasets where exact counts aren’t critical—it provides significant performance improvements with controlled error bounds.
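
If the collected arrays later need to be processed row by row, explode() reverses the aggregation; a short sketch based on the product_aggregations DataFrame above:

from pyspark.sql.functions import explode

# One output row per element of the unique_products array
exploded = product_aggregations.select(
    "category",
    explode("unique_products").alias("product")
)

exploded.show()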

Window Functions for Advanced Analytics

Window functions perform calculations across sets of rows related to the current row without collapsing the result set.

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, rank, dense_rank, lag, lead

# Prepare sample data with timestamps
time_series_data = [
    ("2024-01-01", "ProductA", 100),
    ("2024-01-02", "ProductA", 120),
    ("2024-01-03", "ProductA", 115),
    ("2024-01-01", "ProductB", 200),
    ("2024-01-02", "ProductB", 210),
    ("2024-01-03", "ProductB", 205)
]

ts_df = spark.createDataFrame(time_series_data, ["date", "product", "sales"])

# Define window specifications
product_window = Window.partitionBy("product").orderBy("date")
unbounded_window = Window.partitionBy("product").orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow)

# Apply window functions
result = ts_df.withColumn("row_num", row_number().over(product_window)) \
    .withColumn("sales_rank", rank().over(product_window)) \
    .withColumn("prev_day_sales", lag("sales", 1).over(product_window)) \
    .withColumn("next_day_sales", lead("sales", 1).over(product_window)) \
    .withColumn("running_total", sum("sales").over(unbounded_window)) \
    .withColumn("running_avg", avg("sales").over(unbounded_window))

result.show()

Window specifications define the partition and ordering for calculations. The rowsBetween() method specifies frame boundaries—use unboundedPreceding for cumulative calculations or define sliding windows for moving averages.
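
As a further illustration of frame boundaries, the sketch below defines a centered three-row frame (one row before through one row after the current row) over the same ts_df; the window and output column names are illustrative:

# Centered 3-row frame: previous row, current row, next row
centered_window = Window.partitionBy("product") \
    .orderBy("date") \
    .rowsBetween(-1, 1)

centered = ts_df.withColumn("centered_avg", avg("sales").over(centered_window))
centered.show()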

Moving Averages and Time-Based Windows

Time-based analytics often require sliding window calculations over specific time ranges.

from pyspark.sql.functions import unix_timestamp, from_unixtime

# Create timestamp column
ts_df_with_ts = ts_df.withColumn("timestamp", unix_timestamp("date", "yyyy-MM-dd"))

# Define range-based window (7 days)
range_window = Window.partitionBy("product") \
    .orderBy("timestamp") \
    .rangeBetween(-7 * 24 * 60 * 60, 0)  # 7 days in seconds

# Calculate moving averages
moving_avg_df = ts_df_with_ts.withColumn(
    "moving_avg_7day", avg("sales").over(range_window)
).withColumn(
    "moving_sum_7day", sum("sales").over(range_window)
)

moving_avg_df.select("date", "product", "sales", "moving_avg_7day").show()

# Row-based sliding window
row_window = Window.partitionBy("product") \
    .orderBy("date") \
    .rowsBetween(-2, 0)  # Current row and 2 preceding

sliding_result = ts_df.withColumn(
    "ma_3period", avg("sales").over(row_window)
)

sliding_result.show()

Range-based windows use actual value ranges (like timestamps), while row-based windows use physical row positions. Prefer range frames when rows can be missing or unevenly spaced in the ordering column, and row frames when a fixed count of neighboring rows is what matters.

Custom Aggregate Functions (UDAFs)

For specialized aggregation logic not covered by built-in functions, PySpark offers pandas-based grouped aggregate UDFs, the Python counterpart to User-Defined Aggregate Functions (UDAFs).

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd

# Pandas UDAF for weighted average
@pandas_udf(DoubleType())
def weighted_average(prices: pd.Series, quantities: pd.Series) -> float:
    total_value = (prices * quantities).sum()
    total_quantity = quantities.sum()
    return total_value / total_quantity if total_quantity > 0 else 0.0

# Apply custom aggregation
weighted_avg_result = df.groupBy("category").agg(
    weighted_average("price", "quantity").alias("weighted_avg_price")
)

weighted_avg_result.show()

# Returning multiple statistics per group: a grouped-aggregate pandas UDF must
# return a single scalar, so use applyInPandas, which maps each group's
# pandas DataFrame to an output DataFrame
from pyspark.sql.types import StructType, StructField, StringType

stats_schema = StructType([
    StructField("category", StringType()),
    StructField("total", DoubleType()),
    StructField("count", DoubleType()),
    StructField("ratio", DoubleType())
])

def custom_stats(pdf: pd.DataFrame) -> pd.DataFrame:
    total = float((pdf["price"] * pdf["quantity"]).sum())
    count = float(len(pdf))
    ratio = total / count if count > 0 else 0.0
    return pd.DataFrame(
        [[pdf["category"].iloc[0], total, count, ratio]],
        columns=["category", "total", "count", "ratio"]
    )

complex_result = df.groupBy("category").applyInPandas(custom_stats, stats_schema)

complex_result.show()

Pandas-based grouped aggregate UDFs and applyInPandas leverage vectorized operations for better performance compared to row-at-a-time UDFs. They're particularly effective for complex mathematical operations.

Performance Optimization Strategies

Aggregation performance depends on data distribution, partition strategy, and execution plan optimization.

# Check rows per grouping key to spot skew
df.groupBy("category").count().show()

# Co-locate rows by the grouping key before aggregation
# (note: repartitioning alone does not fix a single hot key)
balanced_df = df.repartition("category")

# Use broadcast for small dimension tables
from pyspark.sql.functions import broadcast

large_fact = spark.range(1000000).withColumn("category_id", expr("id % 100"))
small_dim = spark.range(100).withColumnRenamed("id", "category_id")

# Broadcast join before aggregation
result = large_fact.join(broadcast(small_dim), "category_id") \
    .groupBy("category_id") \
    .agg(count("*").alias("count"))

# Enable adaptive query execution
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

# Verify that partial (map-side) aggregation appears in the physical plan
# for high-cardinality grouping keys
high_cardinality_agg = df.groupBy("product").agg(sum("quantity").alias("total_quantity"))
high_cardinality_agg.explain()

Monitor query plans using explain() to identify bottlenecks. Partition skew causes stragglers—use salting techniques or adaptive query execution to mitigate. For very large aggregations, consider pre-aggregating data at ingestion time or using incremental aggregation patterns.
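
One common salting pattern is a two-stage aggregation; a hedged sketch, assuming eight salt buckets (tune to the observed skew) and reusing the df_with_revenue DataFrame from earlier:

from pyspark.sql.functions import rand, floor

SALT_BUCKETS = 8  # assumption: adjust to the degree of skew

# Stage 1: spread each hot key across SALT_BUCKETS partial groups
salted = df_with_revenue.withColumn("salt", floor(rand() * SALT_BUCKETS))
partials = salted.groupBy("category", "salt") \
    .agg(sum("revenue").alias("partial_revenue"))

# Stage 2: combine the partial results per original key
desalted = partials.groupBy("category") \
    .agg(sum("partial_revenue").alias("total_revenue"))

desalted.show()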

The Catalyst optimizer automatically applies partial (map-side) aggregation for most operations. For edge cases involving high-cardinality keys or complex multi-stage aggregations, where partial aggregation reduces little data, verify the plan with explain() and lean on adaptive execution, repartitioning, or salting.
