Spark SQL - UDF (User Defined Functions) Guide

Key Insights

  • UDFs in Spark SQL bridge the gap between built-in functions and custom business logic, but they break Catalyst optimizer optimizations and should be used judiciously
  • Scala and Python UDFs have very different performance characteristics: Scala UDFs execute inside the JVM, while Python UDFs add serialization overhead between the JVM and Python workers that can slow execution by 10x or more
  • Pandas UDFs (vectorized UDFs) offer near-native performance for Python by processing data in batches using Apache Arrow, making them the preferred choice for production Python workloads

Understanding UDF Performance Implications

User Defined Functions in Spark SQL allow you to extend Spark’s built-in functionality with custom logic. However, they come with significant trade-offs. When you use a UDF, Spark’s Catalyst optimizer treats it as a black box and cannot apply its optimization strategies like predicate pushdown or column pruning.

Here’s a simple example demonstrating the difference:

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.appName("UDFExample").getOrCreate()

# Sample data
data = [(1, 100), (2, 200), (3, 300)]
df = spark.createDataFrame(data, ["id", "value"])

# Built-in function (optimized)
df_builtin = df.withColumn("doubled", col("value") * 2)

# UDF approach (not optimized)
@udf(returnType=IntegerType())
def double_value(x):
    return x * 2

df_udf = df.withColumn("doubled", double_value(col("value")))

The built-in approach generates optimized physical plans, while the UDF version requires row-by-row processing with serialization overhead.

Creating Scala UDFs

Scala UDFs execute directly in the JVM, avoiding serialization costs. They’re the preferred choice when performance is critical.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{udf, lit}

val spark = SparkSession.builder()
  .appName("ScalaUDFExample")
  .getOrCreate()

import spark.implicits._

// Define a simple UDF
val calculateTax = udf((amount: Double, rate: Double) => amount * rate)

// Create sample data
val transactions = Seq(
  (1, "Product A", 100.0),
  (2, "Product B", 250.0),
  (3, "Product C", 75.0)
).toDF("id", "product", "amount")

// Apply UDF
val withTax = transactions.withColumn("tax", calculateTax($"amount", lit(0.08)))
withTax.show()

// Register for SQL usage
spark.udf.register("calculateTax", (amount: Double, rate: Double) => amount * rate)

transactions.createOrReplaceTempView("transactions")
spark.sql("""
  SELECT id, product, amount, calculateTax(amount, 0.08) as tax
  FROM transactions
""").show()

For complex return types, you can return a case class and let Spark derive the struct schema, or supply a schema explicitly:

import org.apache.spark.sql.types._

case class ParsedData(category: String, value: Double)

val parseComplexString = udf((input: String) => {
  val parts = input.split(":")
  ParsedData(parts(0), parts(1).toDouble)
})

// Alternative: untyped UDF with an explicit schema. Deprecated since
// Spark 3.0 and disabled by default (spark.sql.legacy.allowUntypedScalaUDF);
// prefer the case class form above.
val schema = StructType(Seq(
  StructField("category", StringType, false),
  StructField("value", DoubleType, false)
))

val parseWithSchema = udf((input: String) => {
  val parts = input.split(":")
  (parts(0), parts(1).toDouble)
}, schema)

Python UDFs and Their Limitations

Python UDFs require serialization between the JVM and Python processes, creating substantial overhead. Use them only when necessary.

from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType, StructType, StructField, DoubleType

# Simple string transformation
@udf(returnType=StringType())
def clean_email(email):
    if email:
        return email.lower().strip()
    return None

# Complex type return
schema = StructType([
    StructField("first_name", StringType(), True),
    StructField("last_name", StringType(), True)
])

@udf(returnType=schema)
def parse_name(full_name):
    if not full_name:
        return None
    parts = full_name.split()
    if len(parts) >= 2:
        return (parts[0], parts[-1])
    return (full_name, "")

# Usage
users = spark.createDataFrame([
    (1, "John.Doe@EXAMPLE.COM", "John Doe"),
    (2, "Jane.Smith@TEST.com", "Jane Smith")
], ["id", "email", "name"])

cleaned = users.withColumn("clean_email", clean_email(col("email"))) \
               .withColumn("parsed_name", parse_name(col("name")))

cleaned.select("id", "clean_email", "parsed_name.*").show(truncate=False)

Pandas UDFs for Vectorized Operations

Pandas UDFs (also called vectorized UDFs) process data in batches using Apache Arrow, dramatically improving performance over standard Python UDFs.

from pyspark.sql.functions import pandas_udf, struct
import pandas as pd
import numpy as np

# Series to Series transformation
@pandas_udf(DoubleType())
def calculate_discount(prices: pd.Series, quantities: pd.Series) -> pd.Series:
    # Vectorized operations on entire batches
    discount_rate = np.where(quantities > 10, 0.15, 0.05)
    return prices * discount_rate

# DataFrame to DataFrame (for multiple outputs)
schema = StructType([
    StructField("total", DoubleType(), True),
    StructField("discount", DoubleType(), True),
    StructField("final_amount", DoubleType(), True)
])

@pandas_udf(schema)
def calculate_pricing(pdf: pd.DataFrame) -> pd.DataFrame:
    total = pdf['price'] * pdf['quantity']
    discount = np.where(pdf['quantity'] > 10, total * 0.15, total * 0.05)
    final = total - discount
    
    return pd.DataFrame({
        'total': total,
        'discount': discount,
        'final_amount': final
    })

# Sample data
sales = spark.createDataFrame([
    (1, 100.0, 5),
    (2, 50.0, 15),
    (3, 200.0, 8),
    (4, 75.0, 20)
], ["id", "price", "quantity"])

# Apply Pandas UDF
result = sales.withColumn("discount_amount", 
                         calculate_discount(col("price"), col("quantity")))

# Or use the DataFrame version
result2 = sales.select("id", calculate_pricing(struct("price", "quantity")).alias("pricing")) \
               .select("id", "pricing.*")

result2.show()

Grouped Aggregation with Pandas UDFs

Pandas UDFs shine in grouped aggregations, where they can replace complex window functions or custom aggregations.

from pyspark.sql.functions import pandas_udf

# Grouped aggregate UDF
@pandas_udf(DoubleType())
def weighted_average(values: pd.Series, weights: pd.Series) -> float:
    return (values * weights).sum() / weights.sum()

# Grouped map UDF - transforms entire groups
schema = StructType([
    StructField("category", StringType(), True),
    StructField("date", StringType(), True),
    StructField("normalized_value", DoubleType(), True)
])

# Written as a plain function and applied with applyInPandas;
# the @pandas_udf(..., PandasUDFType.GROUPED_MAP) decorator form
# is deprecated since Spark 3.0.
def normalize_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # Normalize values within each group
    mean = pdf['value'].mean()
    std = pdf['value'].std()

    return pd.DataFrame({
        'category': pdf['category'],
        'date': pdf['date'],
        'normalized_value': (pdf['value'] - mean) / std if std > 0 else 0.0
    })

# Sample time series data
data = [
    ("A", "2024-01-01", 100.0, 1.0),
    ("A", "2024-01-02", 150.0, 1.5),
    ("B", "2024-01-01", 200.0, 2.0),
    ("B", "2024-01-02", 250.0, 2.5)
]
df = spark.createDataFrame(data, ["category", "date", "value", "weight"])

# Grouped aggregation
weighted_avg = df.groupBy("category").agg(
    weighted_average(col("value"), col("weight")).alias("weighted_avg")
)

# Grouped transformation
normalized = df.groupBy("category").applyInPandas(normalize_by_group, schema)
normalized.show()

Best Practices and Optimization

Always benchmark UDFs against built-in functions. Prefer built-in functions first; when custom logic is unavoidable, reach for Scala UDFs or Pandas UDFs (depending on whether your team works in Scala or Python), and treat plain Python UDFs as a last resort.

# Anti-pattern: Using UDF for simple operations
@udf(returnType=StringType())
def concatenate(a, b):
    return f"{a}_{b}"

# Better: Use built-in functions
from pyspark.sql.functions import concat_ws
df.withColumn("combined", concat_ws("_", col("a"), col("b")))

# Cache UDF results when reusing
expensive_udf_result = df.withColumn("computed", expensive_udf(col("input")))
expensive_udf_result.cache()

# Use broadcast variables for lookup tables
from pyspark.sql.functions import broadcast

lookup_table = spark.createDataFrame(lookup_data)
result = df.join(broadcast(lookup_table), "key")

For null handling in UDFs:

@pandas_udf(DoubleType())
def safe_divide(numerator: pd.Series, denominator: pd.Series) -> pd.Series:
    # Pandas handles nulls naturally
    return numerator / denominator.replace(0, np.nan)

# In regular Python UDFs, check explicitly
@udf(returnType=DoubleType())
def safe_divide_python(num, den):
    if num is None or den is None or den == 0:
        return None
    return num / den

UDFs are powerful but expensive. Exhaust built-in options first, profile your code, and choose the UDF type that matches your performance requirements and team’s skill set.
