Spark SQL - UDAF (User Defined Aggregate Functions)
Key Insights
- User Defined Aggregate Functions (UDAFs) in Spark SQL enable custom aggregation logic beyond built-in functions like SUM and AVG, operating on groups of rows to produce single output values
- Spark 3.0+ supports two UDAF implementations: the legacy UserDefinedAggregateFunction (deprecated) and the type-safe Aggregator API with Dataset encoders for better performance and compile-time safety
- UDAFs are essential for domain-specific calculations like geometric mean, percentile approximations, or complex statistical operations where standard SQL aggregates fall short
Understanding UDAF Architecture
User Defined Aggregate Functions process multiple input rows and return a single aggregated result. Unlike UDFs that operate row-by-row, UDAFs maintain internal state across rows within each partition, then combine results across partitions.
The UDAF lifecycle consists of four key phases:
- Initialization: Create a zero value for the buffer
- Update: Process each input row and update the buffer
- Merge: Combine buffers from different partitions
- Evaluate: Convert the final buffer to the output value
This design enables distributed processing where each executor maintains local aggregation state before merging results at the driver.
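Before looking at the Spark API, the four phases can be simulated with plain Python. This is an illustrative sketch, not Spark code: each function mirrors one lifecycle phase, and the two lists stand in for partitions held by different executors.

```python
from functools import reduce as fold

# Illustrative lifecycle only, not Spark API: each function mirrors one phase.
def zero():                      # Initialization: create an empty buffer
    return {"sum": 0.0, "count": 0}

def update(buf, value):          # Update: fold one input row into the buffer
    buf["sum"] += value
    buf["count"] += 1
    return buf

def merge(b1, b2):               # Merge: combine buffers from two partitions
    return {"sum": b1["sum"] + b2["sum"], "count": b1["count"] + b2["count"]}

def evaluate(buf):               # Evaluate: convert the buffer to the output (mean)
    return buf["sum"] / buf["count"] if buf["count"] else 0.0

partitions = [[1.0, 2.0, 3.0], [4.0, 5.0]]
# Each "executor" aggregates its own partition locally...
local = [fold(update, part, zero()) for part in partitions]
# ...then the partial buffers are merged and evaluated once.
final = evaluate(fold(merge, local, zero()))
print(final)  # 3.0
```

The key property is that `merge` must be associative and commutative, so the result is the same no matter how rows are split across partitions.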
Implementing UDAFs with the Aggregator API
The modern Aggregator API provides type-safe UDAF implementation. Here’s a geometric mean calculator that demonstrates the complete pattern:
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
case class GeometricMeanBuffer(var product: Double, var count: Long)
object GeometricMean extends Aggregator[Double, GeometricMeanBuffer, Double] {
  // Zero value - starting point for aggregation
  def zero: GeometricMeanBuffer = GeometricMeanBuffer(1.0, 0L)

  // Combine input value with buffer
  def reduce(buffer: GeometricMeanBuffer, value: Double): GeometricMeanBuffer = {
    buffer.product *= value
    buffer.count += 1
    buffer
  }

  // Merge two buffers (from different partitions)
  def merge(b1: GeometricMeanBuffer, b2: GeometricMeanBuffer): GeometricMeanBuffer = {
    GeometricMeanBuffer(
      b1.product * b2.product,
      b1.count + b2.count
    )
  }

  // Convert buffer to final output
  def finish(buffer: GeometricMeanBuffer): Double = {
    if (buffer.count == 0) 0.0
    else math.pow(buffer.product, 1.0 / buffer.count)
  }

  // Encoder for buffer type
  def bufferEncoder: Encoder[GeometricMeanBuffer] = Encoders.product

  // Encoder for output type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
// Usage: Aggregator.toColumn only works on typed Datasets, so wrap the
// Aggregator with functions.udaf to apply it to an untyped DataFrame
import org.apache.spark.sql.functions.udaf

val spark = SparkSession.builder()
  .appName("UDAF Example")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

val data = Seq(
  ("A", 2.0),
  ("A", 8.0),
  ("B", 3.0),
  ("B", 27.0)
).toDF("category", "value")

val geomMean = udaf(GeometricMean)

data.groupBy("category")
  .agg(geomMean($"value").as("geom_mean"))
  .show()
// Output:
// +--------+------------------+
// |category| geom_mean|
// +--------+------------------+
// | A| 4.0|
// | B| 9.0|
// +--------+------------------+
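The buffer arithmetic can be sanity-checked without Spark. This plain-Python sketch mirrors the product/count buffer and confirms that merging per-partition buffers gives the same answer as a single pass:

```python
import math

def geom_mean(product, count):
    # finish phase: count-th root of the running product
    return math.pow(product, 1.0 / count) if count else 0.0

# Group A split across two partitions, as Spark might distribute it
p1 = (2.0, 1)   # (product, count) after reducing [2.0]
p2 = (8.0, 1)   # (product, count) after reducing [8.0]
merged = (p1[0] * p2[0], p1[1] + p2[1])  # merge: multiply products, add counts

print(geom_mean(*merged))        # 4.0, same as a single pass over [2.0, 8.0]
print(geom_mean(3.0 * 27.0, 2))  # 9.0 for group B
```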
Handling Complex Input Types
UDAFs can accept multiple input columns using case classes. This weighted average implementation demonstrates multi-column input:
case class WeightedValue(value: Double, weight: Double)
case class WeightedBuffer(var sumProduct: Double, var sumWeight: Double)
object WeightedAverage extends Aggregator[WeightedValue, WeightedBuffer, Double] {
  def zero: WeightedBuffer = WeightedBuffer(0.0, 0.0)

  def reduce(buffer: WeightedBuffer, input: WeightedValue): WeightedBuffer = {
    buffer.sumProduct += input.value * input.weight
    buffer.sumWeight += input.weight
    buffer
  }

  def merge(b1: WeightedBuffer, b2: WeightedBuffer): WeightedBuffer = {
    WeightedBuffer(
      b1.sumProduct + b2.sumProduct,
      b1.sumWeight + b2.sumWeight
    )
  }

  def finish(buffer: WeightedBuffer): Double = {
    if (buffer.sumWeight == 0.0) 0.0
    else buffer.sumProduct / buffer.sumWeight
  }

  def bufferEncoder: Encoder[WeightedBuffer] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
// Usage: a UDAF built from a case-class input takes one column per field,
// so pass price and weight as separate columns
import org.apache.spark.sql.functions.udaf

val salesData = Seq(
  ("Product1", 100.0, 0.3),
  ("Product1", 150.0, 0.7),
  ("Product2", 200.0, 0.5),
  ("Product2", 180.0, 0.5)
).toDF("product", "price", "weight")

val weightedAvg = udaf(WeightedAverage)

salesData
  .groupBy("product")
  .agg(weightedAvg($"price", $"weight").as("weighted_avg_price"))
  .show()
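A quick plain-Python check of the same arithmetic (mirroring the reduce and finish phases, not Spark itself) shows the results to expect for the sample data:

```python
def weighted_avg(rows):
    # rows: (value, weight) pairs; one pass mirrors reduce + finish on a buffer
    sum_product = sum(v * w for v, w in rows)
    sum_weight = sum(w for _, w in rows)
    return sum_product / sum_weight if sum_weight else 0.0

print(weighted_avg([(100.0, 0.3), (150.0, 0.7)]))  # ~135.0 for Product1
print(weighted_avg([(200.0, 0.5), (180.0, 0.5)]))  # 190.0 for Product2
print(weighted_avg([]))                            # 0.0: the empty-group guard
```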
PySpark UDAF Implementation
PySpark uses the pandas UDF approach for aggregate functions, leveraging Apache Arrow for efficient data transfer:
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd
import numpy as np
spark = SparkSession.builder.appName("PySpark UDAF").getOrCreate()
@pandas_udf(DoubleType())
def geometric_mean(values: pd.Series) -> float:
    if len(values) == 0:
        return 0.0
    # Use log to avoid overflow with large products
    log_sum = np.sum(np.log(values))
    return np.exp(log_sum / len(values))
# Create sample data
data = spark.createDataFrame([
    ("A", 2.0),
    ("A", 8.0),
    ("B", 3.0),
    ("B", 27.0)
], ["category", "value"])

result = data.groupBy("category").agg(
    geometric_mean("value").alias("geom_mean")
)
result.show()
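The log-space computation matters in practice: a direct running product overflows for long series of large values, while the log-sum stays finite. A small NumPy check, independent of Spark:

```python
import numpy as np

values = np.full(500, 1e100)   # direct product would be 1e50000, far past float64

naive = np.prod(values)        # overflows to inf
log_space = np.exp(np.sum(np.log(values)) / len(values))

print(naive)       # inf
print(log_space)   # ~1e100, the correct geometric mean
```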
For grouped aggregations that return several values per group, use groupBy().applyInPandas, which replaces the deprecated GROUPED_MAP pandas UDF:
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

schema = StructType([
    StructField("category", StringType()),
    StructField("std_dev", DoubleType()),
    StructField("variance", DoubleType())
])

def custom_statistics(pdf):
    return pd.DataFrame({
        'category': [pdf['category'].iloc[0]],
        'std_dev': [pdf['value'].std()],
        'variance': [pdf['value'].var()]
    })

result = data.groupBy("category").applyInPandas(custom_statistics, schema=schema)
result.show()
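The per-group statistics can be verified directly in pandas, which is exactly what the grouped function receives for each key:

```python
import pandas as pd

pdf = pd.DataFrame({"category": ["A", "A", "B", "B"],
                    "value": [2.0, 8.0, 3.0, 27.0]})

# pandas std()/var() default to ddof=1 (sample statistics),
# matching the custom_statistics function above
stats = pdf.groupby("category")["value"].agg(["std", "var"])
print(stats)
```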
Performance Considerations
UDAFs introduce overhead compared to built-in functions due to serialization and custom code execution. Optimize performance with these strategies:
// Use primitive types in buffers to minimize serialization
case class OptimizedBuffer(var sum: Double, var count: Long, var sumSquares: Double)

// Avoid object creation in reduce: mutate and return the existing buffer
def reduce(buffer: OptimizedBuffer, value: Double): OptimizedBuffer = {
  buffer.sum += value
  buffer.count += 1
  buffer.sumSquares += value * value
  buffer // Reuse existing buffer
}

// Register as a SQL function (wrapping the Aggregator with functions.udaf)
// so it can be reused across queries like a built-in aggregate
import org.apache.spark.sql.functions.udaf
spark.udf.register("geom_mean", udaf(GeometricMean))
spark.sql("""
  SELECT category, geom_mean(value) AS result
  FROM data_table
  GROUP BY category
""")
Registering UDAFs for SQL Usage
Make UDAFs available in SQL queries through registration:
// Register Aggregator-based UDAFs (wrap each with functions.udaf)
import org.apache.spark.sql.functions.udaf
spark.udf.register("geometric_mean", udaf(GeometricMean))
spark.udf.register("weighted_avg", udaf(WeightedAverage))

// Create temp view
data.createOrReplaceTempView("metrics")

// Use in SQL
spark.sql("""
  SELECT
    category,
    geometric_mean(value) AS geom_mean,
    AVG(value) AS arithmetic_mean
  FROM metrics
  GROUP BY category
""").show()
Testing UDAFs
Verify UDAF correctness with unit tests covering edge cases:
import org.scalatest.funsuite.AnyFunSuite
class GeometricMeanTest extends AnyFunSuite {

  test("geometric mean of powers of 2") {
    val buffer = GeometricMean.zero
    val result = GeometricMean.reduce(
      GeometricMean.reduce(buffer, 2.0),
      8.0
    )
    assert(GeometricMean.finish(result) == 4.0)
  }

  test("handles empty input") {
    val buffer = GeometricMean.zero
    assert(GeometricMean.finish(buffer) == 0.0)
  }

  test("merge combines buffers correctly") {
    val b1 = GeometricMeanBuffer(4.0, 2)
    val b2 = GeometricMeanBuffer(9.0, 2)
    val merged = GeometricMean.merge(b1, b2)
    assert(merged.product == 36.0)
    assert(merged.count == 4)
  }
}
UDAFs extend Spark SQL’s aggregation capabilities for specialized analytical requirements while maintaining distributed processing semantics. The Aggregator API provides type safety and performance benefits over legacy approaches, making it the preferred implementation method for production systems.