Spark SQL - UDAF (User Defined Aggregate Functions)

Key Insights

  • User Defined Aggregate Functions (UDAFs) in Spark SQL enable custom aggregation logic beyond built-in functions like SUM and AVG, operating on a group of rows to produce a single output value per group
  • Spark 3.0+ supports two UDAF implementations: the legacy UserDefinedAggregateFunction (deprecated since 3.0) and the type-safe Aggregator API, which uses Dataset encoders for better performance and compile-time safety and can be called from DataFrames and SQL through functions.udaf
  • UDAFs fill gaps for domain-specific calculations such as the geometric mean, weighted averages, or other statistical measures that standard SQL aggregates do not cover

Understanding UDAF Architecture

User Defined Aggregate Functions process multiple input rows and return a single aggregated result. Unlike UDFs that operate row-by-row, UDAFs maintain internal state across rows within each partition, then combine results across partitions.

The UDAF lifecycle consists of four key phases:

  1. Initialization (zero): Create an empty starting value for the buffer
  2. Update (reduce): Fold each input row into the buffer
  3. Merge (merge): Combine buffers produced by different partitions
  4. Evaluate (finish): Convert the final buffer into the output value

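This design enables distributed processing. Each task aggregates its own rows into a partial buffer, and a shuffle brings together the partial buffers for each group. Executors, not the driver, then merge those buffers and evaluate the final result. Rows can be split across partitions in any way, so merge must produce the same result however the input was divided.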
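The four phases can be simulated in plain Scala without Spark, which makes the contract easy to see. This is a minimal sketch under stated assumptions. AvgBuffer and the helper functions are illustrative names, each inner Seq stands in for one partition, and an average is used because its buffer is simple. In the Aggregator API, these functions correspond to zero, reduce, merge, and finish.

```scala
// Framework-free model of the four UDAF phases for an average.
// Each inner Seq stands in for one partition's rows (illustrative only).
case class AvgBuffer(var sum: Double, var count: Long)

// 1. Initialization: a fresh, empty buffer on every call
def zero: AvgBuffer = AvgBuffer(0.0, 0L)

// 2. Update: fold one input row into the buffer (mutates and returns it)
def update(buffer: AvgBuffer, value: Double): AvgBuffer = {
  buffer.sum += value
  buffer.count += 1
  buffer
}

// 3. Merge: combine two partial buffers into a new one
def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer =
  AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)

// 4. Evaluate: turn the final buffer into the output value
def evaluate(buffer: AvgBuffer): Double =
  if (buffer.count == 0) 0.0 else buffer.sum / buffer.count

val partitions = Seq(Seq(1.0, 2.0, 3.0), Seq(4.0), Seq(5.0, 6.0))

// Each partition builds its own partial buffer independently
val partials = partitions.map(rows => rows.foldLeft(zero)(update))

// Partial buffers are merged starting from an empty buffer, then evaluated once
val result = evaluate(partials.foldLeft(zero)(merge))
println(result) // 3.5
```

Because update mutates its buffer, zero must return a fresh instance on every call. That is why it is a def rather than a val.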

Implementing UDAFs with the Aggregator API

The Aggregator API provides a type-safe way to implement UDAFs. Here's a geometric mean calculator that demonstrates the complete pattern:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

case class GeometricMeanBuffer(var product: Double, var count: Long)

object GeometricMean extends Aggregator[Double, GeometricMeanBuffer, Double] {
  
  // Zero value - starting point for aggregation
  def zero: GeometricMeanBuffer = GeometricMeanBuffer(1.0, 0L)
  
  // Combine input value with buffer
  def reduce(buffer: GeometricMeanBuffer, value: Double): GeometricMeanBuffer = {
    buffer.product *= value
    buffer.count += 1
    buffer
  }
  
  // Merge two buffers (from different partitions)
  def merge(b1: GeometricMeanBuffer, b2: GeometricMeanBuffer): GeometricMeanBuffer = {
    GeometricMeanBuffer(
      b1.product * b2.product,
      b1.count + b2.count
    )
  }
  
  // Convert buffer to final output
  def finish(buffer: GeometricMeanBuffer): Double = {
    if (buffer.count == 0) 0.0
    else math.pow(buffer.product, 1.0 / buffer.count)
  }
  
  // Encoder for buffer type
  def bufferEncoder: Encoder[GeometricMeanBuffer] = Encoders.product
  
  // Encoder for output type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Usage
val spark = SparkSession.builder()
  .appName("UDAF Example")
  .master("local[*]")
  .getOrCreate()

import spark.implicits._
import org.apache.spark.sql.functions.{col, udaf}

val data = Seq(
  ("A", 2.0),
  ("A", 8.0),
  ("B", 3.0),
  ("B", 27.0)
).toDF("category", "value")

// Wrap the typed Aggregator with udaf() so it can be applied to DataFrame columns
val geomMean = udaf(GeometricMean)

data.groupBy("category")
  .agg(geomMean(col("value")).as("geom_mean"))
  .orderBy("category")
  .show()

// Output:
// +--------+---------+
// |category|geom_mean|
// +--------+---------+
// |       A|      4.0|
// |       B|      9.0|
// +--------+---------+
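GeometricMean multiplies raw values. For long groups, the product overflows to Infinity (or underflows to 0.0) before finish takes the root. Summing logarithms avoids this, as the PySpark version later in this article does. The sketch below applies the same idea in Scala. LogMeanBuffer and LogGeometricMean are hypothetical names. The Aggregator parent and encoders are left out so the code runs without Spark; in a real job it would extend Aggregator[Double, LogMeanBuffer, Double] with the same encoders as GeometricMean. The sketch assumes strictly positive inputs.

```scala
// Log-space geometric mean (hypothetical names; Spark encoders omitted).
// Assumes strictly positive inputs, since log(x) is undefined for x <= 0.
case class LogMeanBuffer(var logSum: Double, var count: Long)

object LogGeometricMean {
  def zero: LogMeanBuffer = LogMeanBuffer(0.0, 0L)

  // Add the logarithm instead of multiplying the raw value
  def reduce(buffer: LogMeanBuffer, value: Double): LogMeanBuffer = {
    buffer.logSum += math.log(value)
    buffer.count += 1
    buffer
  }

  // Sums of logs combine by addition, just as products combine by multiplication
  def merge(b1: LogMeanBuffer, b2: LogMeanBuffer): LogMeanBuffer =
    LogMeanBuffer(b1.logSum + b2.logSum, b1.count + b2.count)

  // exp(mean of logs) equals the n-th root of the product
  def finish(buffer: LogMeanBuffer): Double =
    if (buffer.count == 0) 0.0 else math.exp(buffer.logSum / buffer.count)
}

// 400 copies of 1000.0: the true product is 1e1200, far beyond Double's range
val values = Seq.fill(400)(1000.0)
val naiveProduct = values.product
val logMean = LogGeometricMean.finish(
  values.foldLeft(LogGeometricMean.zero)(LogGeometricMean.reduce))

println(naiveProduct) // Infinity
println(logMean)      // approximately 1000.0
```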

Handling Complex Input Types

An Aggregator can take several input columns when its input type is a case class, with each column corresponding to one field. This weighted average implementation demonstrates multi-column input:

case class WeightedValue(value: Double, weight: Double)
case class WeightedBuffer(var sumProduct: Double, var sumWeight: Double)

object WeightedAverage extends Aggregator[WeightedValue, WeightedBuffer, Double] {
  
  def zero: WeightedBuffer = WeightedBuffer(0.0, 0.0)
  
  def reduce(buffer: WeightedBuffer, input: WeightedValue): WeightedBuffer = {
    buffer.sumProduct += input.value * input.weight
    buffer.sumWeight += input.weight
    buffer
  }
  
  def merge(b1: WeightedBuffer, b2: WeightedBuffer): WeightedBuffer = {
    WeightedBuffer(
      b1.sumProduct + b2.sumProduct,
      b1.sumWeight + b2.sumWeight
    )
  }
  
  def finish(buffer: WeightedBuffer): Double = {
    if (buffer.sumWeight == 0.0) 0.0
    else buffer.sumProduct / buffer.sumWeight
  }
  
  def bufferEncoder: Encoder[WeightedBuffer] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Usage: price and weight are passed as two columns and mapped,
// by position, onto the fields of WeightedValue (value, weight)
import org.apache.spark.sql.functions.{col, udaf}

val salesData = Seq(
  ("Product1", 100.0, 0.3),
  ("Product1", 150.0, 0.7),
  ("Product2", 200.0, 0.5),
  ("Product2", 180.0, 0.5)
).toDF("product", "price", "weight")

val weightedAvg = udaf(WeightedAverage)

salesData
  .groupBy("product")
  .agg(weightedAvg(col("price"), col("weight")).as("weighted_avg_price"))
  .show()

PySpark UDAF Implementation

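PySpark implements custom aggregates with pandas UDFs, using Apache Arrow to move column batches between the JVM and Python. A pandas UDF whose type hints map a pd.Series to a scalar is treated as a grouped aggregate function. Unlike an Aggregator, it does not support partial aggregation, so all values of a group are loaded into memory at once. Here is the geometric mean in this style: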

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd
import numpy as np

spark = SparkSession.builder.appName("PySpark UDAF").getOrCreate()

@pandas_udf(DoubleType())
def geometric_mean(values: pd.Series) -> float:
    if len(values) == 0:
        return 0.0
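    # Sum logs instead of multiplying, to avoid overflow with large products
    # (assumes strictly positive values, since log(x) is undefined for x <= 0)
    log_sum = np.sum(np.log(values))
    return float(np.exp(log_sum / len(values)))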

# Create sample data
data = spark.createDataFrame([
    ("A", 2.0),
    ("A", 8.0),
    ("B", 3.0),
    ("B", 27.0)
], ["category", "value"])

result = data.groupBy("category").agg(
    geometric_mean("value").alias("geom_mean")
)
result.show()

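When each group needs several output values, groupBy().applyInPandas() passes the whole group to a function as a pandas DataFrame and expects a pandas DataFrame matching a declared schema in return. It replaces the groupBy().apply() call with PandasUDFType.GROUPED_MAP, which is deprecated since Spark 3.0:

import pandas as pd
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

schema = StructType([
    StructField("category", StringType()),
    StructField("std_dev", DoubleType()),
    StructField("variance", DoubleType())
])

# Receives all rows of one group as a pandas DataFrame; returns one summary row
def custom_statistics(pdf: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame({
        'category': [pdf['category'].iloc[0]],
        'std_dev': [pdf['value'].std()],    # sample standard deviation (ddof=1)
        'variance': [pdf['value'].var()]    # sample variance (ddof=1)
    })

result = data.groupBy("category").applyInPandas(custom_statistics, schema=schema)
result.show()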

Performance Considerations

UDAFs run slower than built-in aggregates. Catalyst treats the custom code as a black box and cannot optimize inside it. In addition, buffers are serialized through their encoders when partial results are shuffled. These strategies reduce the overhead:

import org.apache.spark.sql.functions.udaf

// Use primitive fields in buffers to keep encoding simple and compact
case class OptimizedBuffer(var sum: Double, var count: Long, var sumSquares: Double)

// Inside an Aggregator: mutate and return the existing buffer
// instead of allocating a new object for every row
def reduce(buffer: OptimizedBuffer, value: Double): OptimizedBuffer = {
  buffer.sum += value
  buffer.count += 1
  buffer.sumSquares += value * value
  buffer  // Reuse existing buffer
}

// Registering makes the UDAF callable from SQL; Catalyst still cannot
// optimize the code inside it
spark.udf.register("geom_mean", udaf(GeometricMean))
data.createOrReplaceTempView("data_table")

spark.sql("""
  SELECT category, geom_mean(value) AS result
  FROM data_table
  GROUP BY category
""").show()

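The OptimizedBuffer above tracks sum, count, and sumSquares, which suggests computing the variance as sumSquares/count - (sum/count)^2. When the mean is large relative to the spread, that formula subtracts two nearly equal large numbers and loses most of its precision. A commonly used alternative, sketched below, keeps (count, mean, m2) instead. Here, reduce applies Welford's online update and merge applies Chan et al.'s pairwise combination formula. VarBuffer and StableVariance are hypothetical names. The Aggregator parent and encoders are omitted so the sketch runs on its own.

```scala
// Variance buffer using Welford's update and Chan et al.'s pairwise merge
// (hypothetical names; an alternative to a sum/sumSquares buffer).
case class VarBuffer(var count: Long, var mean: Double, var m2: Double)

object StableVariance {
  def zero: VarBuffer = VarBuffer(0L, 0.0, 0.0)

  // Welford: update the running mean and the sum of squared deviations (m2)
  def reduce(b: VarBuffer, x: Double): VarBuffer = {
    b.count += 1
    val delta = x - b.mean
    b.mean += delta / b.count
    b.m2 += delta * (x - b.mean)
    b
  }

  // Chan et al.: combine two partial buffers without revisiting the rows
  def merge(a: VarBuffer, b: VarBuffer): VarBuffer =
    if (a.count == 0) b
    else if (b.count == 0) a
    else {
      val n = a.count + b.count
      val delta = b.mean - a.mean
      VarBuffer(
        n,
        a.mean + delta * b.count / n,
        a.m2 + b.m2 + delta * delta * a.count * b.count / n
      )
    }

  // Population variance; divide by (count - 1) instead for the sample variance
  def finish(b: VarBuffer): Double =
    if (b.count == 0) 0.0 else b.m2 / b.count
}

val xs = Seq(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)

// Simulate two partitions, then merge their partial buffers
val left = xs.take(3).foldLeft(StableVariance.zero)(StableVariance.reduce)
val right = xs.drop(3).foldLeft(StableVariance.zero)(StableVariance.reduce)
println(StableVariance.finish(StableVariance.merge(left, right))) // approximately 4.0

// Same spread shifted by 1e9, where sumSquares-based formulas lose most precision
val shifted = xs.map(_ + 1e9)
val shiftedVar = StableVariance.finish(
  shifted.foldLeft(StableVariance.zero)(StableVariance.reduce))
println(shiftedVar) // close to 4.0
```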
Registering UDAFs for SQL Usage

To call an Aggregator from SQL, wrap it with functions.udaf and register the resulting function under a name:

// Wrap each Aggregator with udaf() and register the result for SQL
import org.apache.spark.sql.functions.udaf

spark.udf.register("geometric_mean", udaf(GeometricMean))
spark.udf.register("weighted_avg", udaf(WeightedAverage))

// Create temp view
data.createOrReplaceTempView("metrics")

// Use in SQL
spark.sql("""
  SELECT 
    category,
    geometric_mean(value) as geom_mean,
    AVG(value) as arithmetic_mean
  FROM metrics
  GROUP BY category
""").show()

Testing UDAFs

Verify UDAF correctness with unit tests covering edge cases:

import org.scalatest.funsuite.AnyFunSuite

class GeometricMeanTest extends AnyFunSuite {
  
  test("geometric mean of powers of 2") {
    val buffer = GeometricMean.zero
    val result = GeometricMean.reduce(
      GeometricMean.reduce(buffer, 2.0),
      8.0
    )
    assert(GeometricMean.finish(result) == 4.0)
  }
  
  test("handles empty input") {
    val buffer = GeometricMean.zero
    assert(GeometricMean.finish(buffer) == 0.0)
  }
  
  test("merge combines buffers correctly") {
    val b1 = GeometricMeanBuffer(4.0, 2)
    val b2 = GeometricMeanBuffer(9.0, 2)
    val merged = GeometricMean.merge(b1, b2)
    assert(merged.product == 36.0)
    assert(merged.count == 4)
  }
}
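The unit tests above check fixed cases. A complementary check is that the result does not depend on how rows are split across partitions, which is exactly what merge and zero must guarantee. The sketch below runs without any framework. aggregateInChunks is a hypothetical helper, not a Spark or ScalaTest API, and the chunk sizes stand in for different partitionings. WeightedAverage's buffer logic is copied as plain functions so the code runs without Spark. The same assertion could go inside a test(...) block in a suite like the one above.

```scala
import scala.util.Random

// Hypothetical test helper (not a Spark or ScalaTest API): splits rows into
// chunks of chunkSize, folds each chunk from a fresh zero like a partition
// would, then merges the partial buffers and calls finish.
def aggregateInChunks[IN, BUF, OUT](
    rows: Seq[IN],
    chunkSize: Int,
    zero: () => BUF,
    reduce: (BUF, IN) => BUF,
    merge: (BUF, BUF) => BUF,
    finish: BUF => OUT): OUT = {
  val partials = rows.grouped(chunkSize).map(chunk => chunk.foldLeft(zero())(reduce)).toList
  finish(partials.foldLeft(zero())(merge))
}

// A copy of WeightedAverage's logic as plain functions, so no Spark is needed
case class WBuf(var sumProduct: Double, var sumWeight: Double)
val wZero = () => WBuf(0.0, 0.0)
val wReduce = (b: WBuf, row: (Double, Double)) => {
  b.sumProduct += row._1 * row._2
  b.sumWeight += row._2
  b
}
val wMerge = (b1: WBuf, b2: WBuf) =>
  WBuf(b1.sumProduct + b2.sumProduct, b1.sumWeight + b2.sumWeight)
val wFinish = (b: WBuf) => if (b.sumWeight == 0.0) 0.0 else b.sumProduct / b.sumWeight

// 1000 random (value, weight) rows with a fixed seed
val rnd = new Random(42)
val rows = Seq.fill(1000)((rnd.nextDouble() * 100, rnd.nextDouble()))

// Chunk sizes stand in for different partitionings of the same data
val results = Seq(1, 7, 100, 1000).map(n =>
  aggregateInChunks(rows, n, wZero, wReduce, wMerge, wFinish))

// Floating-point addition is not exactly associative, so compare with a tolerance
val consistent = results.forall(r => math.abs(r - results.head) < 1e-9)
println(consistent) // true
```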

UDAFs extend Spark SQL’s aggregation capabilities for specialized analytical requirements while maintaining distributed processing semantics. The Aggregator API provides type safety and performance benefits over legacy approaches, making it the preferred implementation method for production systems.
