How to Use Window Functions in PySpark
Key Insights
- Window functions let you perform calculations across related rows while preserving the original row structure—unlike groupBy, which collapses rows into aggregates.
- The window specification has three components: partitioning (which rows to group), ordering (how to sort within groups), and frame (which rows to include in the calculation).
- Proper partitioning is critical for performance; poorly partitioned window operations can cause data skew and out-of-memory errors on large datasets.
Introduction to Window Functions
Window functions are one of the most powerful features in PySpark for analytical workloads. They let you perform calculations across a set of rows that are somehow related to the current row—without collapsing those rows into a single result like groupBy does.
Consider a common scenario: you have a table of employee salaries and want to calculate each employee’s salary as a percentage of their department’s total. With groupBy, you’d need to aggregate first, then join back to the original data. With window functions, you do it in a single operation.
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql import functions as F
spark = SparkSession.builder.appName("WindowFunctions").getOrCreate()
# Sample data
data = [
("Engineering", "Alice", 95000),
("Engineering", "Bob", 85000),
("Engineering", "Carol", 90000),
("Sales", "Dave", 70000),
("Sales", "Eve", 75000),
]
df = spark.createDataFrame(data, ["department", "name", "salary"])
The key insight: window functions return a value for every input row. Your DataFrame keeps its shape.
Understanding the Window Specification
Every window function needs a window specification that defines three things:
- Partitioning: Which rows belong together (like GROUP BY, but without collapsing)
- Ordering: How rows are sorted within each partition
- Frame: Which rows relative to the current row are included in the calculation
You create a window specification using the Window class:
from pyspark.sql.window import Window
# Basic window: partition by department, order by salary descending
window_spec = Window.partitionBy("department").orderBy(F.desc("salary"))
# Window with explicit frame specification
window_with_frame = (
Window.partitionBy("department")
.orderBy("date")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
Not all components are required for every function. Ranking functions need ordering but the frame is implicit. Aggregate functions without ordering operate on the entire partition.
Ranking Functions
PySpark provides four ranking functions, each with distinct behavior:
- row_number(): Assigns sequential integers starting at 1. No ties—arbitrary assignment for equal values.
- rank(): Same values get the same rank, but subsequent ranks skip. (1, 2, 2, 4)
- dense_rank(): Same values get the same rank, no gaps. (1, 2, 2, 3)
- ntile(n): Divides rows into n roughly equal buckets.
window_spec = Window.partitionBy("department").orderBy(F.desc("salary"))
ranked_df = df.withColumn("row_num", F.row_number().over(window_spec)) \
.withColumn("rank", F.rank().over(window_spec)) \
.withColumn("dense_rank", F.dense_rank().over(window_spec)) \
.withColumn("quartile", F.ntile(4).over(window_spec))
ranked_df.show()
The most common use case is finding the top N records per group:
# Get the highest-paid employee in each department
top_per_dept = (
df.withColumn("rn", F.row_number().over(window_spec))
.filter(F.col("rn") == 1)
.drop("rn")
)
Use row_number() when you need exactly N results per group. Use rank() or dense_rank() when you want to include all ties.
Aggregate Window Functions
Standard aggregate functions—sum, avg, count, min, max—work as window functions when you apply them over a window specification. This enables running totals, moving averages, and percentage calculations.
# Sales data over time
sales_data = [
("2024-01-01", "Electronics", 1000),
("2024-01-02", "Electronics", 1500),
("2024-01-03", "Electronics", 1200),
("2024-01-01", "Clothing", 500),
("2024-01-02", "Clothing", 600),
("2024-01-03", "Clothing", 550),
]
sales_df = spark.createDataFrame(sales_data, ["date", "category", "amount"])
# Running total within each category
running_window = (
Window.partitionBy("category")
.orderBy("date")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
sales_with_running = sales_df.withColumn(
"running_total", F.sum("amount").over(running_window)
)
For percentage of total calculations, use a window without ordering to aggregate across the entire partition:
# Each employee's salary as percentage of department total
dept_window = Window.partitionBy("department")
df_with_pct = df.withColumn(
"dept_total", F.sum("salary").over(dept_window)
).withColumn(
"pct_of_dept", F.round(F.col("salary") / F.col("dept_total") * 100, 2)
)
df_with_pct.show()
Notice the window has no orderBy. Without ordering, aggregate window functions operate on the entire partition, giving you the total for comparison.
Analytic Functions (Lead/Lag)
Analytic functions let you access values from other rows relative to the current row:
- lag(col, n): Gets the value from n rows before the current row
- lead(col, n): Gets the value from n rows after the current row
- first_value(col): Gets the first value in the window frame
- last_value(col): Gets the last value in the window frame

(F.first_value and F.last_value were added to pyspark.sql.functions in Spark 3.5; on earlier versions, use F.first and F.last, which behave the same over a window.)
These are invaluable for time-series analysis and change detection:
# Month-over-month sales change
time_window = Window.partitionBy("category").orderBy("date")
sales_with_change = sales_df.withColumn(
"prev_day_amount", F.lag("amount", 1).over(time_window)
).withColumn(
"daily_change", F.col("amount") - F.col("prev_day_amount")
).withColumn(
"pct_change",
F.round((F.col("amount") - F.col("prev_day_amount")) / F.col("prev_day_amount") * 100, 2)
)
sales_with_change.show()
You can provide a default value for when there’s no previous/next row:
# Default to 0 when there's no previous row
sales_df.withColumn(
"prev_amount", F.lag("amount", 1, 0).over(time_window)
)
A common gotcha with last_value(): by default, the window frame ends at the current row, so last_value() just returns the current row’s value. You need an explicit frame to get the actual last value:
full_window = (
Window.partitionBy("category")
.orderBy("date")
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)
sales_df.withColumn("final_amount", F.last_value("amount").over(full_window))
Frame Specifications
Frame specifications define exactly which rows are included in a window calculation relative to the current row. You have two options:
- rowsBetween(start, end): Physical offset based on row position
- rangeBetween(start, end): Logical offset based on the ordering column's value
The boundaries can be:
- Window.unboundedPreceding: All rows before the current row
- Window.unboundedFollowing: All rows after the current row
- Window.currentRow: The current row (value 0)
- Any integer: That many rows before (negative) or after (positive)
# 7-day moving average
daily_sales = [
("2024-01-01", 100), ("2024-01-02", 150), ("2024-01-03", 120),
("2024-01-04", 180), ("2024-01-05", 90), ("2024-01-06", 200),
("2024-01-07", 160), ("2024-01-08", 140), ("2024-01-09", 170),
]
daily_df = spark.createDataFrame(daily_sales, ["date", "sales"])
# 7-day moving average (current row + 6 preceding); no partitionBy here,
# which is fine for this tiny demo but pulls all rows onto one executor at scale
moving_avg_window = (
Window.orderBy("date")
.rowsBetween(-6, Window.currentRow)
)
daily_df.withColumn(
"seven_day_avg", F.round(F.avg("sales").over(moving_avg_window), 2)
).show()
The difference between rowsBetween and rangeBetween matters when you have gaps in your data. rowsBetween(-6, 0) always includes exactly 7 rows (if available). rangeBetween would include all rows within a value range of the ordering column—useful for true calendar-day calculations when dates might be missing.
# Convert date to numeric for rangeBetween
from pyspark.sql.functions import datediff, lit, to_date
daily_df_with_days = daily_df.withColumn(
"day_num", datediff(to_date("date"), lit("2024-01-01"))
)
# Range-based: all sales within 6 days of the current row's date value
range_window = Window.orderBy("day_num").rangeBetween(-6, 0)
daily_df_with_days.withColumn(
"range_avg", F.round(F.avg("sales").over(range_window), 2)
).show()
Performance Considerations and Best Practices
Window functions can be expensive. Here’s how to keep them performant:
Partition wisely. Each partition is processed independently, so choose partition keys that distribute data evenly. A partition key with high cardinality (many unique values) spreads work across executors. A single partition key value containing millions of rows will bottleneck on one executor.
# Bad: Single partition processes all data on one executor
bad_window = Window.orderBy("date")
# Better: Partition by a column with reasonable cardinality
good_window = Window.partitionBy("region").orderBy("date")
Watch for skew. If one partition has vastly more data than others, that executor becomes a bottleneck. Monitor your job’s task durations—if one task takes 10x longer than others, you have skew. Consider salting your partition key or filtering out the problematic values for separate processing.
Prefer window functions over self-joins. A common anti-pattern is joining a table to itself to compare rows. Window functions almost always perform better:
# Anti-pattern: Self-join to get previous row
df_prev = df.alias("prev")
df_curr = df.alias("curr")
# ...followed by a join on the partition key plus an off-by-one
# condition on the ordering column, then column pruning—verbose,
# error-prone, and usually slower than the window version
# Better: Use lag()
df.withColumn("prev_value", F.lag("value", 1).over(window_spec))
Minimize the number of window specifications. Each distinct window specification triggers a separate shuffle. If you need multiple calculations, try to use the same window specification:
# One shuffle
window = Window.partitionBy("dept").orderBy("date")
df.withColumn("running_sum", F.sum("amount").over(window)) \
.withColumn("running_avg", F.avg("amount").over(window)) \
.withColumn("row_num", F.row_number().over(window))
Cache intermediate results. If you’re applying multiple different window specifications to the same DataFrame, consider caching after expensive transformations to avoid recomputation.
Window functions are essential for analytical workloads in PySpark. Master the window specification, understand when to use each function type, and pay attention to partitioning—you’ll handle most analytical requirements without resorting to complex joins or UDFs.