Spark Scala - Window Functions
Key Insights
- Window functions let you compute aggregates and rankings across related rows without collapsing your data—you keep every row while adding computed columns based on a “window” of surrounding rows.
- The three pillars of window specifications—partitioning, ordering, and frame bounds—give you precise control over which rows participate in each calculation.
- Window functions often replace expensive self-joins and subqueries, leading to cleaner code and better performance when used correctly.
Introduction to Window Functions
Window functions solve a fundamental problem in data processing: how do you compute values across multiple rows while keeping each row intact? Standard aggregations with GROUP BY collapse rows into summary statistics. Window functions perform the same calculations but attach results back to individual rows.
Consider a sales dataset where you need both the individual sale amount and the total sales for that product category. With GROUP BY, you’d need a separate aggregation query and then join it back. With window functions, it’s a single pass.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val spark = SparkSession.builder()
.appName("WindowFunctions")
.master("local[*]")
.getOrCreate()
import spark.implicits._
val sales = Seq(
("Electronics", "Laptop", 1200),
("Electronics", "Phone", 800),
("Electronics", "Tablet", 500),
("Clothing", "Jacket", 150),
("Clothing", "Shirt", 50)
).toDF("category", "product", "amount")
// GROUP BY approach - loses individual rows
val groupedTotals = sales.groupBy("category").agg(sum("amount").as("category_total"))
// Window function approach - keeps all rows
val windowSpec = Window.partitionBy("category")
val withWindowTotal = sales.withColumn("category_total", sum("amount").over(windowSpec))
withWindowTotal.show()
// +-----------+-------+------+--------------+
// |   category|product|amount|category_total|
// +-----------+-------+------+--------------+
// |   Clothing| Jacket|   150|           200|
// |   Clothing|  Shirt|    50|           200|
// |Electronics| Laptop|  1200|          2500|
// |Electronics|  Phone|   800|          2500|
// |Electronics| Tablet|   500|          2500|
// +-----------+-------+------+--------------+
The window function preserves every row while computing the aggregate. This is the core value proposition.
Window Specification Fundamentals
Every window function operates over a WindowSpec that defines three things: which rows to group together (partitioning), how to order those rows (ordering), and which subset of rows to include in each calculation (frame).
val fullWindowSpec = Window
.partitionBy("category") // Group rows by category
.orderBy($"amount".desc) // Order within each partition
.rowsBetween(Window.unboundedPreceding, Window.currentRow) // Frame bounds
// Partitioning alone (for simple aggregates)
val partitionOnly = Window.partitionBy("category")
// Partitioning with ordering (required for ranking functions)
val partitionAndOrder = Window.partitionBy("category").orderBy($"amount".desc)
Partitioning works like GROUP BY—it defines independent groups. Ordering determines the sequence within each partition, which matters for rankings and running calculations. Frame specification controls exactly which rows relative to the current row participate in aggregate calculations.
Not all combinations are valid. Ranking functions require ordering but ignore frame specifications. Aggregate functions use frames but don’t strictly require ordering (though they usually need it for meaningful results like running totals).
Ranking Functions
Ranking functions assign positions to rows within each partition. The four main functions handle ties differently:
- row_number(): Sequential integers, no gaps, arbitrary tie-breaking
- rank(): Same rank for ties, gaps after ties (1, 2, 2, 4)
- dense_rank(): Same rank for ties, no gaps (1, 2, 2, 3)
- ntile(n): Distributes rows into n roughly equal buckets
val products = Seq(
("Electronics", "Laptop", 15000),
("Electronics", "Phone", 12000),
("Electronics", "Tablet", 12000),
("Electronics", "Headphones", 8000),
("Clothing", "Jacket", 5000),
("Clothing", "Shoes", 4500),
("Clothing", "Shirt", 3000),
("Clothing", "Hat", 1500)
).toDF("category", "product", "sales")
val rankWindow = Window.partitionBy("category").orderBy($"sales".desc)
val ranked = products
.withColumn("row_num", row_number().over(rankWindow))
.withColumn("rank", rank().over(rankWindow))
.withColumn("dense_rank", dense_rank().over(rankWindow))
// Get top 3 products per category
val top3PerCategory = ranked.filter($"row_num" <= 3)
top3PerCategory.show()
// +-----------+-------+-----+-------+----+----------+
// |   category|product|sales|row_num|rank|dense_rank|
// +-----------+-------+-----+-------+----+----------+
// |   Clothing| Jacket| 5000|      1|   1|         1|
// |   Clothing|  Shoes| 4500|      2|   2|         2|
// |   Clothing|  Shirt| 3000|      3|   3|         3|
// |Electronics| Laptop|15000|      1|   1|         1|
// |Electronics|  Phone|12000|      2|   2|         2|
// |Electronics| Tablet|12000|      3|   2|         2|
// +-----------+-------+-----+-------+----+----------+
Notice how Phone and Tablet both have rank 2 (tied sales), but different row numbers. Use row_number() for deduplication when you need exactly one row per group. Use rank() or dense_rank() when ties should be preserved.
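The deduplication pattern is worth a sketch of its own, reusing the products DataFrame and imports from above (dedupWindow is an illustrative name, not a Spark API):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Keep exactly one row per category: the top seller.
val dedupWindow = Window.partitionBy("category").orderBy($"sales".desc)

val topPerCategory = products
  .withColumn("rn", row_number().over(dedupWindow))
  .filter($"rn" === 1)
  .drop("rn")
```

Ties are broken arbitrarily here; add tie-breaker columns to the orderBy (a timestamp or primary key, say) if you need deterministic results across runs.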
Analytic Functions
Analytic functions access values from other rows relative to the current row. These are invaluable for time-series analysis and sequential comparisons.
val dailySales = Seq(
("2024-01-01", 1000),
("2024-01-02", 1200),
("2024-01-03", 950),
("2024-01-04", 1100),
("2024-01-05", 1350),
("2024-01-06", 1400),
("2024-01-07", 1250)
).toDF("date", "sales")
val timeWindow = Window.orderBy("date")
val withAnalytics = dailySales
.withColumn("prev_day_sales", lag("sales", 1).over(timeWindow))
.withColumn("next_day_sales", lead("sales", 1).over(timeWindow))
.withColumn("daily_change", $"sales" - lag("sales", 1).over(timeWindow))
.withColumn("pct_change",
round(($"sales" - lag("sales", 1).over(timeWindow)) /
lag("sales", 1).over(timeWindow) * 100, 2))
withAnalytics.show()
// +----------+-----+--------------+--------------+------------+----------+
// |      date|sales|prev_day_sales|next_day_sales|daily_change|pct_change|
// +----------+-----+--------------+--------------+------------+----------+
// |2024-01-01| 1000|          null|          1200|        null|      null|
// |2024-01-02| 1200|          1000|           950|         200|      20.0|
// |2024-01-03|  950|          1200|          1100|        -250|    -20.83|
// |2024-01-04| 1100|           950|          1350|         150|     15.79|
// |2024-01-05| 1350|          1100|          1400|         250|     22.73|
// |2024-01-06| 1400|          1350|          1250|          50|       3.7|
// |2024-01-07| 1250|          1400|          null|        -150|    -10.71|
// +----------+-----+--------------+--------------+------------+----------+
The lag() function looks backward, lead() looks forward. Both take an offset and an optional default value to return at partition edges instead of null. The first() and last() functions (first_value and last_value in Spark SQL) grab the first or last value in the window frame—useful for carrying forward reference values.
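One frame gotcha deserves a sketch: over the default ordered frame, the last value the frame can see is the current row itself, so last() just echoes each row's own value. Widening the frame to the whole partition (reusing the dailySales DataFrame from above) gives the true last value:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Full-partition frame: first() and last() see every row in the window.
val fullFrame = Window.orderBy("date")
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

dailySales
  .withColumn("first_sale", first($"sales").over(fullFrame)) // 1000 on every row
  .withColumn("last_sale", last($"sales").over(fullFrame))   // 1250 on every row
```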
Aggregate Functions Over Windows
Standard aggregates become powerful when combined with window specifications. Running totals, moving averages, and cumulative statistics fall out naturally.
val orderedWindow = Window.orderBy("date")
val movingWindow = Window.orderBy("date").rowsBetween(-2, 0) // Current + 2 preceding
val withAggregates = dailySales
.withColumn("running_total", sum("sales").over(orderedWindow))
.withColumn("cumulative_avg", round(avg("sales").over(orderedWindow), 2))
.withColumn("3day_moving_avg", round(avg("sales").over(movingWindow), 2))
.withColumn("running_max", max("sales").over(orderedWindow))
withAggregates.show()
// +----------+-----+-------------+--------------+---------------+-----------+
// |      date|sales|running_total|cumulative_avg|3day_moving_avg|running_max|
// +----------+-----+-------------+--------------+---------------+-----------+
// |2024-01-01| 1000|         1000|        1000.0|         1000.0|       1000|
// |2024-01-02| 1200|         2200|        1100.0|         1100.0|       1200|
// |2024-01-03|  950|         3150|        1050.0|         1050.0|       1200|
// |2024-01-04| 1100|         4250|        1062.5|        1083.33|       1200|
// |2024-01-05| 1350|         5600|        1120.0|        1133.33|       1350|
// |2024-01-06| 1400|         7000|       1166.67|        1283.33|       1400|
// |2024-01-07| 1250|         8250|       1178.57|        1333.33|       1400|
// +----------+-----+-------------+--------------+---------------+-----------+
When a window has an ordering but no explicit frame, Spark defaults to rangeBetween(Window.unboundedPreceding, Window.currentRow), which is what produces running totals. Explicit frame bounds like rowsBetween(-2, 0) create sliding windows for moving averages.
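That default is worth writing out explicitly. Note it is a range frame, not a row frame, so rows that tie on the ordering column are included together—a sketch; the two specs below behave identically on dailySales only because its dates are unique:

```scala
import org.apache.spark.sql.expressions.Window

// An ordered window with no explicit frame defaults to
// rangeBetween(unboundedPreceding, currentRow).
val implicitFrame = Window.orderBy("date")
val explicitFrame = Window.orderBy("date")
  .rangeBetween(Window.unboundedPreceding, Window.currentRow)
```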
Frame Specifications Deep Dive
Frame specifications control exactly which rows participate in aggregate calculations. The distinction between rowsBetween and rangeBetween trips up many developers.
rowsBetween counts physical row positions. rangeBetween uses the actual values in the ordering column—rows with the same value are treated as peers.
val dataWithTies = Seq(
(1, 100), (2, 100), (3, 100), (4, 200), (5, 200), (6, 300)
).toDF("id", "value")
val rowFrame = Window.orderBy("value").rowsBetween(Window.unboundedPreceding, Window.currentRow)
val rangeFrame = Window.orderBy("value").rangeBetween(Window.unboundedPreceding, Window.currentRow)
dataWithTies
.withColumn("row_sum", sum("id").over(rowFrame))
.withColumn("range_sum", sum("id").over(rangeFrame))
.show()
// +---+-----+-------+---------+
// | id|value|row_sum|range_sum|
// +---+-----+-------+---------+
// |  1|  100|      1|        6| // range includes all rows with value <= 100
// |  2|  100|      3|        6|
// |  3|  100|      6|        6|
// |  4|  200|     10|       15| // range includes all rows with value <= 200
// |  5|  200|     15|       15|
// |  6|  300|     21|       21|
// +---+-----+-------+---------+
With rowsBetween, each row gets a different running sum based on physical position. With rangeBetween, all rows with value 100 get the same sum because they’re peers in the ordering. Use rowsBetween for positional calculations and rangeBetween when logical value ranges matter.
Performance Considerations
Window functions can be expensive. Each partition must fit in memory on a single executor, and Spark must shuffle data to colocate partition members. An unpartitioned window (an orderBy with no partitionBy, as in the time-series examples above) is the worst case: Spark pulls the entire dataset into a single partition, which is fine for small demos but will not scale.
Watch for partition skew. If one category has 90% of your data, that partition becomes a bottleneck. Consider adding secondary partition keys or pre-filtering large partitions.
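When the windowed value is a plain per-group aggregate, one mitigation sketch is to pre-aggregate with groupBy and join the small result back, rather than windowing over the skewed partition. Reusing the sales DataFrame from the introduction:

```scala
import org.apache.spark.sql.functions._

// Aggregate first: one row per category, small enough to broadcast.
val categoryTotals = sales.groupBy("category")
  .agg(sum("amount").as("category_total"))

// A broadcast join attaches the totals without shuffling the large side
// by the skewed key.
val withTotals = sales.join(broadcast(categoryTotals), "category")
```

This produces the same result as the intro's window example, at the cost of a second DataFrame and a join.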
Window functions often replace self-joins, yielding significant performance gains:
// Slow: Self-join to get previous value
val withPrevJoin = dailySales.as("curr")
.join(
dailySales.as("prev"),
$"curr.date" === date_add($"prev.date", 1),
"left"
)
.select($"curr.date", $"curr.sales", $"prev.sales".as("prev_sales"))
// Fast: Window function
val withPrevWindow = dailySales
.withColumn("prev_sales", lag("sales", 1).over(Window.orderBy("date")))
The window function version avoids the shuffle join entirely. Check the Spark UI's SQL tab to compare execution plans—the window version shows a single Window operator with at most one exchange (for partitioning and sorting), rather than the multiple exchanges and join operators of the self-join.
One caveat: if you need the same window calculation multiple times, Spark won’t always cache the window computation. Extract complex window results to intermediate DataFrames when reusing them extensively.
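A sketch of that extraction pattern, reusing the products DataFrame and the rank window from the ranking section:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Compute the window once, cache the intermediate result, reuse it downstream.
val rankWindow = Window.partitionBy("category").orderBy($"sales".desc)

val ranked = products
  .withColumn("rnk", rank().over(rankWindow))
  .cache() // materialized on the first action, reused by later queries

val categoryLeaders = ranked.filter($"rnk" === 1)
val top3 = ranked.filter($"rnk" <= 3)
```

Without the cache(), each downstream query would repeat the shuffle, sort, and rank computation.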
Window functions are one of Spark’s most powerful features for analytics workloads. Master the three components—partitioning, ordering, and frames—and you’ll write cleaner, faster data transformations.