PySpark - SQL CASE WHEN Statement
Key Insights
- PySpark’s CASE WHEN provides conditional logic for data transformation through both SQL syntax and DataFrame API methods (when().otherwise()). Both compile to the same optimized plan, so performance is equivalent; the DataFrame API adds IDE support and lets you compose conditions in ordinary Python code.
- Complex business rules requiring multiple conditions are best expressed by chaining WHEN clauses rather than nesting CASE statements. This improves readability and maintainability and avoids deeply nested logic.
- CASE WHEN expressions significantly outperform Python user-defined functions (UDFs) for conditional logic. They run as native Spark expressions that the Catalyst optimizer can analyze, while Python UDFs are opaque to the optimizer and require shipping rows to Python worker processes. That makes CASE WHEN the preferred choice for production workloads.
Introduction to CASE WHEN in PySpark
Conditional logic is fundamental to data transformation pipelines. In PySpark, the CASE WHEN statement serves as your primary tool for implementing if-then-else logic at scale across distributed datasets. Whether you’re categorizing customers, calculating tiered pricing, or handling data quality issues, CASE WHEN gives you the flexibility to transform data based on complex conditions.
Unlike Python’s native conditional statements that operate on single values, PySpark’s CASE WHEN works on entire columns, leveraging Spark’s distributed computing model. This means you can apply sophisticated conditional logic to billions of rows without writing explicit loops or sacrificing performance.
You’ll reach for CASE WHEN when you need to create new columns based on existing values, standardize inconsistent data, implement business rules, or prepare data for analytics. The statement works seamlessly in both SQL queries and DataFrame operations, giving you flexibility in how you express your logic.
Basic CASE WHEN Syntax
PySpark offers two approaches to CASE WHEN logic: SQL-style syntax and DataFrame API methods. Both accomplish the same goal and compile to the same execution plan. The DataFrame API lets you build conditions with ordinary Python code and benefits from IDE autocomplete.
Here’s a straightforward example categorizing people by age group:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.appName("CaseWhenExample").getOrCreate()
# Sample data
data = [
    ("Alice", 8),
    ("Bob", 35),
    ("Charlie", 72),
    ("Diana", 16),
    ("Eve", 45)
]
df = spark.createDataFrame(data, ["name", "age"])
# Using DataFrame API
df_with_category = df.withColumn(
    "age_group",
    F.when(F.col("age") < 13, "child")
    .when((F.col("age") >= 13) & (F.col("age") < 65), "adult")
    .otherwise("senior")
)
df_with_category.show()
Output:
+-------+---+---------+
|   name|age|age_group|
+-------+---+---------+
|  Alice|  8|    child|
|    Bob| 35|    adult|
|Charlie| 72|   senior|
|  Diana| 16|    adult|
|    Eve| 45|    adult|
+-------+---+---------+
The when() function takes a condition and a value to return when that condition is true. Chain multiple when() calls for additional conditions, and finish with otherwise() to handle cases that don’t match any condition. If you omit otherwise(), unmatched rows receive NULL values.
Multiple Conditions with CASE WHEN
Real-world business logic rarely involves simple single-condition checks. You’ll frequently need to evaluate multiple factors simultaneously. Here’s a practical example implementing tiered pricing based on quantity and customer type:
# Sample sales data
sales_data = [
    ("laptop", 5, "premium", 1000),
    ("laptop", 15, "standard", 1000),
    ("mouse", 100, "premium", 25),
    ("keyboard", 30, "standard", 75),
    ("monitor", 8, "premium", 300)
]
sales_df = spark.createDataFrame(
    sales_data,
    ["product", "quantity", "customer_type", "unit_price"]
)
# Complex pricing logic
pricing_df = sales_df.withColumn(
    "discount_rate",
    F.when(
        (F.col("customer_type") == "premium") & (F.col("quantity") >= 10),
        0.20
    )
    .when(
        (F.col("customer_type") == "premium") & (F.col("quantity") >= 5),
        0.15
    )
    .when(
        (F.col("customer_type") == "standard") & (F.col("quantity") >= 20),
        0.10
    )
    .when(
        F.col("quantity") >= 50,
        0.05
    )
    .otherwise(0.0)
).withColumn(
    "total_price",
    F.col("quantity") * F.col("unit_price") * (1 - F.col("discount_rate"))
)
pricing_df.show()
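Notice that order matters: conditions are evaluated sequentially, and the first match wins. Place more specific conditions before general ones to ensure correct evaluation. The premium customer’s 100-unit mouse order gets a 20% discount, not 15% or 5%, because the quantity >= 10 branch is checked first. The same rule means the final quantity >= 50 branch can never apply to premium or standard customers, since earlier branches already capture those rows. It only affects rows with some other customer_type.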
CASE WHEN with DataFrame API vs SQL
PySpark gives you multiple ways to express the same logic. Understanding when to use each approach helps you write cleaner, more maintainable code.
DataFrame API approach (recommended for most cases):
result_df = sales_df.withColumn(
    "price_tier",
    F.when(F.col("unit_price") < 50, "budget")
    .when(F.col("unit_price") < 200, "mid-range")
    .otherwise("premium")
)
SQL expression with expr():
result_df = sales_df.withColumn(
    "price_tier",
    F.expr("""
        CASE
            WHEN unit_price < 50 THEN 'budget'
            WHEN unit_price < 200 THEN 'mid-range'
            ELSE 'premium'
        END
    """)
)
Full SQL query:
sales_df.createOrReplaceTempView("products")
result_df = spark.sql("""
    SELECT *,
           CASE
               WHEN unit_price < 50 THEN 'budget'
               WHEN unit_price < 200 THEN 'mid-range'
               ELSE 'premium'
           END AS price_tier
    FROM products
""")
All three forms compile to the same CASE WHEN expression, so they perform identically. The DataFrame API offers IDE autocomplete, lets you store and reuse conditions as Python Column objects, and is easier to unit test. Use SQL expressions when you’re migrating existing SQL code or when the SQL syntax is genuinely clearer for complex logic. The expr() approach provides a middle ground, letting you embed SQL snippets within DataFrame operations.
Advanced Use Cases
Handling NULL Values
NULL handling is critical in production data pipelines. CASE WHEN provides explicit control over NULL behavior:
data_with_nulls = [
    ("Product A", 100, None),
    ("Product B", None, 50),
    ("Product C", 80, 20),
    ("Product D", None, None)
]
df_nulls = spark.createDataFrame(
    data_with_nulls,
    ["product", "stock", "reserved"]
)
df_available = df_nulls.withColumn(
    "available_stock",
    F.when(F.col("stock").isNull() | F.col("reserved").isNull(), 0)
    .otherwise(F.col("stock") - F.col("reserved"))
).withColumn(
    "status",
    F.when(F.col("stock").isNull(), "unknown")
    .when(F.col("available_stock") > 50, "in_stock")
    .when(F.col("available_stock") > 0, "low_stock")
    .otherwise("out_of_stock")
)
df_available.show()
CASE WHEN in Aggregations
Conditional logic within aggregations enables sophisticated analytical queries:
# Sample transaction data
transactions = [
    ("2024-01", "electronics", 1500, "completed"),
    ("2024-01", "electronics", 800, "refunded"),
    ("2024-01", "clothing", 300, "completed"),
    ("2024-02", "electronics", 2000, "completed"),
    ("2024-02", "clothing", 450, "cancelled")
]
trans_df = spark.createDataFrame(
    transactions,
    ["month", "category", "amount", "status"]
)
# Conditional aggregation
summary = trans_df.groupBy("month", "category").agg(
    F.sum(
        F.when(F.col("status") == "completed", F.col("amount"))
        .otherwise(0)
    ).alias("completed_revenue"),
    F.sum(
        F.when(F.col("status") == "refunded", F.col("amount"))
        .otherwise(0)
    ).alias("refunded_amount"),
    F.count(
        F.when(F.col("status") == "completed", 1)
    ).alias("completed_count")
)
summary.show()
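This pattern produces pivot-like summaries in a single aggregation pass, without reshaping the data. Unlike pivot() called without an explicit list of values, it does not need an extra job to discover the distinct pivot values, and it lets you choose exactly which conditional metrics to compute. Note that F.count() works here because it ignores NULLs: rows where the condition is false produce NULL and are not counted.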
Nested Conditions for Complex Business Rules
While you should generally avoid deep nesting, sometimes business rules demand it:
# Credit approval logic
applicants = [
    (720, 50000, 5000, True),
    (680, 75000, 8000, False),
    (750, 60000, 3000, True),
    (620, 45000, 12000, True)
]
app_df = spark.createDataFrame(
    applicants,
    ["credit_score", "income", "debt", "has_collateral"]
)
approval_df = app_df.withColumn(
    "approval_status",
    F.when(
        F.col("credit_score") >= 700,
        F.when(F.col("debt") < F.col("income") * 0.3, "approved")
        .otherwise("manual_review")
    )
    .when(
        F.col("credit_score") >= 650,
        # has_collateral is already a boolean column, so "== True" is unnecessary
        F.when(F.col("has_collateral") & (F.col("debt") < F.col("income") * 0.2), "approved")
        .otherwise("manual_review")
    )
    .otherwise("rejected")
)
approval_df.show()
Performance Considerations and Best Practices
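CASE WHEN expressions execute as native Spark operations that the Catalyst optimizer can analyze and compile into generated code. A Python UDF, by contrast, is a black box to the optimizer, and every row must be serialized to a Python worker process and back. For conditional logic this makes CASE WHEN much faster, which you can measure with a rough benchmark like this one: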
from pyspark.sql.types import StringType
import time
# Create larger dataset for performance testing
large_df = spark.range(0, 1000000).withColumn(
    "value", (F.rand() * 100).cast("int")
)

def time_action(frame):
    # Aggregate on the new column so Spark must actually compute it;
    # a plain count() can let the optimizer drop the unused column entirely
    start = time.time()
    frame.groupBy("category").count().collect()
    return time.time() - start

# Method 1: CASE WHEN (fast)
result1 = large_df.withColumn(
    "category",
    F.when(F.col("value") < 33, "low")
    .when(F.col("value") < 67, "medium")
    .otherwise("high")
)
case_when_time = time_action(result1)

# Method 2: UDF (slow)
def categorize_udf(value):
    if value < 33:
        return "low"
    elif value < 67:
        return "medium"
    else:
        return "high"

categorize = F.udf(categorize_udf, StringType())
result2 = large_df.withColumn("category", categorize(F.col("value")))
udf_time = time_action(result2)

print(f"CASE WHEN: {case_when_time:.2f}s")
print(f"UDF: {udf_time:.2f}s")
print(f"Speedup: {udf_time/case_when_time:.2f}x")
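In typical scenarios, CASE WHEN outperforms an equivalent Python UDF by 5-10x or more. The exact ratio depends on data volume, cluster configuration, and Spark version, so measure on your own workload.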
Best Practices:
- Order conditions from most specific to most general to ensure correct evaluation and potentially short-circuit unnecessary checks.
- Always include otherwise() to handle unexpected cases explicitly rather than allowing NULL values that might cause downstream issues.
- Use the DataFrame API over SQL strings when possible for better IDE support and easier refactoring.
- Avoid excessive nesting: if you find yourself nesting more than 2-3 levels deep, consider breaking the logic into multiple columns or using a lookup table join.
- Cache a DataFrame when several actions or downstream DataFrames reuse it, so its source data is not recomputed each time. Chained withColumn() calls inside a single query are typically collapsed into one projection and gain nothing from caching on their own.
- Test edge cases including NULL values, boundary conditions, and unexpected data types.
CASE WHEN is a workhorse of PySpark data transformation. Master it, and you’ll handle the majority of conditional logic requirements in your data pipelines efficiently and elegantly.