How to Use UDFs in PySpark
Key Insights
- UDFs provide flexibility for custom logic but come with significant performance costs—always prefer built-in Spark functions when they exist
- Pandas UDFs (vectorized UDFs) can be 10-100x faster than standard Python UDFs by leveraging Apache Arrow for efficient data transfer
- Always specify return types explicitly and handle null values defensively to avoid runtime errors and silent data corruption
Introduction to PySpark UDFs
PySpark’s built-in functions cover most data transformation needs, but real-world data is messy. You’ll inevitably encounter scenarios where you need custom logic: proprietary business rules, complex string parsing, or domain-specific calculations that don’t map to standard functions.
User Defined Functions (UDFs) let you write arbitrary Python code and apply it to DataFrame columns. They’re the escape hatch when pyspark.sql.functions falls short.
Here’s the catch: UDFs are slow. Every row gets serialized from the JVM to Python, processed, and serialized back. This serialization overhead can make your job 10x slower than an equivalent built-in function. Use UDFs when you must, but understand the cost.
Creating Basic UDFs
There are two ways to create a UDF: using the udf() function or the @udf decorator. Both accomplish the same thing.
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

spark = SparkSession.builder.appName("UDFDemo").getOrCreate()

# Sample data
data = [("john DOE", "john.doe@email.com"),
        ("JANE smith", "jane.smith@email.com"),
        ("Bob Johnson", "bob@email.com")]
df = spark.createDataFrame(data, ["name", "email"])

# Method 1: Using udf() function
def normalize_name(name):
    """Convert name to title case and strip whitespace."""
    if name is None:
        return None
    return name.strip().title()

normalize_name_udf = udf(normalize_name, StringType())

# Method 2: Using @udf decorator
@udf(returnType=StringType())
def extract_username(email):
    """Extract username from email address."""
    if email is None:
        return None
    return email.split("@")[0]

# Apply UDFs
result = df.withColumn("clean_name", normalize_name_udf(col("name"))) \
           .withColumn("username", extract_username(col("email")))
result.show()
```
Output:

```
+----------+--------------------+----------+----------+
|      name|               email|clean_name|  username|
+----------+--------------------+----------+----------+
|  john DOE|  john.doe@email.com|  John Doe|  john.doe|
|JANE smith|jane.smith@email...|Jane Smith|jane.smith|
|Bob Johnson|      bob@email.com|Bob Johnson|       bob|
+----------+--------------------+----------+----------+
```
I prefer the udf() function approach for production code. It keeps the function definition separate from the UDF registration, making the code easier to test and reuse.
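One payoff of that separation: the logic can be unit-tested as plain Python, with no SparkSession or JVM involved. A minimal sketch, restating the `normalize_name` function from the block above:

```python
# Plain-Python copy of the normalize_name logic above,
# testable without a SparkSession.
def normalize_name(name):
    """Convert name to title case and strip whitespace."""
    if name is None:
        return None
    return name.strip().title()

# These checks run in milliseconds; no Spark cluster required.
assert normalize_name("  john DOE ") == "John Doe"
assert normalize_name(None) is None
```

Only after the pure function passes its tests does it get wrapped with `udf(normalize_name, StringType())`.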
Specifying Return Types
Always declare your return type explicitly. Spark defaults to StringType(), which forces unnecessary type conversions and can silently corrupt your data.
```python
from pyspark.sql.types import (
    StringType, IntegerType, FloatType, BooleanType,
    ArrayType, StructType, StructField
)

# Simple types
@udf(returnType=IntegerType())
def count_words(text):
    if text is None:
        return 0
    return len(text.split())

# Array type
@udf(returnType=ArrayType(StringType()))
def split_and_clean(text):
    if text is None:
        return []
    return [word.lower().strip() for word in text.split()]

# Complex struct type
address_schema = StructType([
    StructField("street", StringType(), True),
    StructField("city", StringType(), True),
    StructField("zip_code", StringType(), True)
])

@udf(returnType=address_schema)
def parse_address(address_string):
    if address_string is None:
        return None
    parts = address_string.split(",")
    if len(parts) >= 3:
        return (parts[0].strip(), parts[1].strip(), parts[2].strip())
    return None

# Usage
sample_data = [("123 Main St, Springfield, 12345",),
               ("456 Oak Ave, Portland, 97201",)]
address_df = spark.createDataFrame(sample_data, ["raw_address"])
address_df.withColumn("parsed", parse_address(col("raw_address"))) \
          .select("raw_address", "parsed.*").show(truncate=False)
```
The wrong return type doesn’t always throw an error—sometimes Spark silently converts values, leading to data quality issues that are hard to debug.
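One defensive habit that helps: coerce the value to the declared type at the end of the function body, so a mismatch fails loudly in Python instead of silently becoming null on the Spark side. A hypothetical sketch (`parse_quantity` is an illustrative function, not from the examples above), written for a UDF declared with `returnType=IntegerType()`:

```python
# Hypothetical parser for a UDF declared with returnType=IntegerType().
# Returning a str by accident would silently become null in the DataFrame;
# the int(...) coercion surfaces that bug as a Python exception instead.
def parse_quantity(raw):
    if raw is None:
        return None
    try:
        return int(raw.strip())
    except ValueError:
        return None  # unparseable input becomes an explicit null

assert parse_quantity(" 42 ") == 42
assert parse_quantity("n/a") is None
assert parse_quantity(None) is None
```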
Registering UDFs for SQL Queries
If you prefer writing SQL or need to share UDFs across your team, register them with the Spark session:
```python
from pyspark.sql.types import DoubleType

def calculate_discount(price, discount_percent):
    if price is None or discount_percent is None:
        return None
    return price * (1 - discount_percent / 100)

# Register for SQL use
spark.udf.register("calc_discount", calculate_discount, DoubleType())

# Create sample data
sales_data = [(100.0, 10.0), (250.0, 15.0), (75.0, 5.0)]
sales_df = spark.createDataFrame(sales_data, ["price", "discount_pct"])
sales_df.createOrReplaceTempView("sales")

# Use in SQL query
result = spark.sql("""
    SELECT
        price,
        discount_pct,
        calc_discount(price, discount_pct) AS final_price
    FROM sales
""")
result.show()
```
Registered UDFs persist for the SparkSession lifetime. In notebooks or long-running applications, this is convenient. In production pipelines, be explicit about registration to avoid confusion.
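Because the registered function is still ordinary Python, its logic stays directly testable outside SQL. Checking `calculate_discount` from the block above against the values the query would produce:

```python
# Same logic as the registered calc_discount UDF above.
def calculate_discount(price, discount_percent):
    if price is None or discount_percent is None:
        return None
    return price * (1 - discount_percent / 100)

# Matches the final_price values the SQL query computes:
assert calculate_discount(100.0, 10.0) == 90.0
assert calculate_discount(250.0, 15.0) == 212.5
assert calculate_discount(None, 10.0) is None
```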
Pandas UDFs (Vectorized UDFs)
Standard UDFs process one row at a time. Pandas UDFs process batches of rows as Pandas Series, using Apache Arrow for efficient columnar data transfer between the JVM and Python. The performance difference is dramatic.
```python
import pandas as pd
from pyspark.sql.functions import pandas_udf, lit

# Standard UDF (slow)
@udf(returnType=FloatType())
def standard_zscore(value, mean, std):
    if value is None or std is None or std == 0:
        return None
    return (value - mean) / std

# Pandas UDF (fast)
@pandas_udf(FloatType())
def vectorized_zscore(values: pd.Series, means: pd.Series, stds: pd.Series) -> pd.Series:
    return (values - means) / stds.replace(0, float("nan"))

# Generate test data
import random
test_data = [(random.gauss(100, 15),) for _ in range(100000)]
test_df = spark.createDataFrame(test_data, ["value"])
test_df = test_df.withColumn("mean_val", lit(100.0)) \
                 .withColumn("std_val", lit(15.0))

# Performance comparison
import time

# Standard UDF timing
start = time.time()
test_df.withColumn("zscore", standard_zscore(col("value"), col("mean_val"), col("std_val"))).count()
standard_time = time.time() - start

# Pandas UDF timing
start = time.time()
test_df.withColumn("zscore", vectorized_zscore(col("value"), col("mean_val"), col("std_val"))).count()
pandas_time = time.time() - start

print(f"Standard UDF: {standard_time:.2f}s")
print(f"Pandas UDF: {pandas_time:.2f}s")
print(f"Speedup: {standard_time / pandas_time:.1f}x")
```
In my testing, Pandas UDFs typically run 10-50x faster for numerical operations. The speedup varies based on data size and operation complexity, but it’s consistently significant.
Common Pitfalls and Best Practices
Most UDF bugs fall into three categories: null handling, serialization issues, and unnecessary UDF usage.
```python
# BAD: Crashes on null values
@udf(returnType=StringType())
def bad_uppercase(text):
    return text.upper()  # NoneType has no attribute 'upper'

# GOOD: Null-safe implementation
@udf(returnType=StringType())
def safe_uppercase(text):
    if text is None:
        return None
    return text.upper()

# BETTER: Use built-in function instead
from pyspark.sql.functions import upper

# This is faster AND handles nulls automatically
df.withColumn("upper_name", upper(col("name")))
```
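When a built-in truly doesn't exist and you have many small UDFs, the null check can be factored out once. A sketch in plain Python (`null_safe` is a hypothetical helper, not a Spark API):

```python
import functools

def null_safe(func):
    """Return None whenever any argument is None; otherwise call func."""
    @functools.wraps(func)
    def wrapper(*args):
        if any(arg is None for arg in args):
            return None
        return func(*args)
    return wrapper

@null_safe
def uppercase(text):
    return text.upper()

assert uppercase("hello") == "HELLO"
assert uppercase(None) is None
```

In Spark you would then wrap the decorated function as usual, e.g. `udf(uppercase, StringType())`.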
Serialization issues occur when your UDF references objects that can’t be pickled:
```python
# BAD: Database connection can't be serialized
import psycopg2

conn = psycopg2.connect(...)  # This breaks

@udf(returnType=StringType())
def lookup_value(key):
    cursor = conn.cursor()  # Fails: conn can't be pickled with the UDF
    # ...

# GOOD: Create connection inside UDF or use broadcast variables
@udf(returnType=StringType())
def lookup_value(key):
    # Create connection per partition, not per row
    # Better: use mapPartitions for connection pooling
    pass

# BEST: Broadcast lookup tables for small datasets
lookup_dict = {"A": "Apple", "B": "Banana"}
broadcast_lookup = spark.sparkContext.broadcast(lookup_dict)

@udf(returnType=StringType())
def broadcast_lookup_udf(key):
    return broadcast_lookup.value.get(key)
```
Before writing a UDF, check if a built-in function exists. The pyspark.sql.functions module has over 200 functions. Many “custom” requirements can be solved with combinations of when, regexp_extract, transform, or aggregate.
Performance Optimization Tips
When you must use UDFs, minimize their impact:
```python
# SLOW: Multiple UDF calls
df.withColumn("a", udf1(col("x"))) \
  .withColumn("b", udf2(col("x"))) \
  .withColumn("c", udf3(col("x")))

# FASTER: Single UDF returning struct
combined_schema = StructType([
    StructField("a", StringType()),
    StructField("b", IntegerType()),
    StructField("c", FloatType())
])

@udf(returnType=combined_schema)
def combined_udf(x):
    return (process_a(x), process_b(x), process_c(x))

df.withColumn("results", combined_udf(col("x"))) \
  .select("*", "results.*")
```
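The same idea in plain Python: do the shared work once and return all fields together, instead of reparsing the input per column. A hypothetical single-pass sketch (the field derivations here are stand-ins for `process_a`/`process_b`/`process_c`):

```python
# Hypothetical single-pass version: one split, three derived fields,
# returned as one tuple matching a (StringType, IntegerType, FloatType) struct.
def combined(x):
    parts = x.split("-")       # the expensive shared work happens once
    a = parts[0]               # "a": first segment
    b = len(parts)             # "b": segment count
    c = float(len(x))          # "c": total length as a float
    return (a, b, c)

assert combined("ab-cd-ef") == ("ab", 3, 8.0)
```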
For lookup operations, broadcast small datasets instead of using UDFs:
```python
from pyspark.sql.functions import broadcast

# SLOW: UDF with external lookup
@udf(returnType=StringType())
def lookup_category(code):
    categories = {"A": "Electronics", "B": "Clothing"}  # Recreated per row!
    return categories.get(code)

# FAST: Broadcast join
category_data = [("A", "Electronics"), ("B", "Clothing")]
category_df = spark.createDataFrame(category_data, ["code", "category"])
main_df.join(broadcast(category_df), "code", "left")
```
The broadcast join avoids Python entirely, keeping all processing in the JVM.
Use Pandas UDFs for any numerical computation. Reserve standard UDFs for operations that genuinely require row-by-row Python logic—complex parsing, external API calls, or business rules that can’t be vectorized.
When profiling, remember that UDF overhead is per-row. A UDF that adds 1 ms per row costs nearly 17 minutes on a million-row dataset. Always test with production-scale data before committing to a UDF-based approach.
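That per-row arithmetic is worth internalizing:

```python
per_row_overhead_s = 0.001   # 1 ms of Python overhead per row
rows = 1_000_000

total_minutes = per_row_overhead_s * rows / 60
print(f"{total_minutes:.1f} minutes")  # about 16.7 minutes of pure overhead
```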