PySpark - SQL Window Functions

Key Insights

  • Window functions compute aggregations across a set of rows related to the current row without collapsing the result set, unlike GROUP BY which reduces rows to one per group
  • Proper window specification with partitionBy() and orderBy() is critical for both correctness and performance—poorly partitioned windows can cause severe data skew and OOM errors
  • Ranking and analytical functions like row_number(), lag(), and lead() eliminate complex self-joins and enable elegant solutions for problems like top-N-per-group and time-series comparisons

Introduction to Window Functions

Window functions are one of PySpark’s most powerful features for analytical queries. Unlike traditional GROUP BY aggregations that collapse multiple rows into a single result, window functions perform calculations across a set of rows while preserving the original row structure. This makes them ideal for ranking, running totals, moving averages, and comparing values across related rows.

The fundamental difference is simple: GROUP BY reduces your dataset, while window functions enrich it. When you need to calculate an aggregate value but keep all your original rows intact, window functions are the answer.

Here’s a concrete example showing the difference:

from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.functions import sum, avg, col

spark = SparkSession.builder.appName("WindowFunctions").getOrCreate()

# Sample sales data
data = [
    ("Electronics", "Laptop", 1200),
    ("Electronics", "Mouse", 25),
    ("Electronics", "Keyboard", 75),
    ("Clothing", "Shirt", 30),
    ("Clothing", "Pants", 50),
    ("Clothing", "Jacket", 120)
]

df = spark.createDataFrame(data, ["category", "product", "price"])

# GROUP BY - collapses to one row per category
grouped = df.groupBy("category").agg(avg("price").alias("avg_price"))
grouped.show()
# Averages rounded to two decimals for display; row order may vary
# +-----------+---------+
# |   category|avg_price|
# +-----------+---------+
# |Electronics|   433.33|
# |   Clothing|    66.67|
# +-----------+---------+

# Window function - keeps all rows, adds average as new column
window_spec = Window.partitionBy("category")
windowed = df.withColumn("category_avg", avg("price").over(window_spec))
windowed.show()
# Averages rounded to two decimals for display; row order may vary
# +-----------+--------+-----+------------+
# |   category| product|price|category_avg|
# +-----------+--------+-----+------------+
# |Electronics|  Laptop| 1200|      433.33|
# |Electronics|   Mouse|   25|      433.33|
# |Electronics|Keyboard|   75|      433.33|
# |   Clothing|   Shirt|   30|       66.67|
# |   Clothing|   Pants|   50|       66.67|
# |   Clothing|  Jacket|  120|       66.67|
# +-----------+--------+-----+------------+

Creating Window Specifications

The Window class defines how to partition and order your data for window function calculations. Think of it as defining the “view” each row has of related rows.

Three key components define a window specification:

  1. partitionBy(): Divides data into groups (like GROUP BY)
  2. orderBy(): Defines row order within each partition
  3. rowsBetween()/rangeBetween(): Defines the window frame (which rows to include)

from pyspark.sql import Window
from pyspark.sql.functions import row_number, sum, col

# Basic window: partition by category, order by price
basic_window = Window.partitionBy("category").orderBy("price")

# Frame specification: the 6 preceding rows plus the current row
# (these specs assume a DataFrame with a "date" column; nothing is computed yet)
frame_window = Window.partitionBy("category").orderBy("date") \
    .rowsBetween(-6, Window.currentRow)

# Unbounded frame - from the start of the partition to the current row
unbounded_window = Window.partitionBy("category").orderBy("date") \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
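The three components can be modeled in a few lines of plain Python. The sketch below is an illustration, not how Spark executes windows: it assumes rows are dicts, supports only ROWS frames, and uses a hypothetical helper window_agg in which None stands for an unbounded side and 0 for the current row (Window.currentRow is 0 in PySpark).

```python
from collections import defaultdict


def window_agg(rows, partition_by, order_by, frame, agg, value, out):
    """Model of agg(value).over(Window.partitionBy(partition_by)
    .orderBy(order_by).rowsBetween(*frame)) for a list of dict rows.

    frame is (start, end) in row offsets; None means unbounded on that side.
    Returns new dicts with an extra `out` key; the input rows are not changed.
    """
    # 1. partitionBy(): group rows by the partition key
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[partition_by]].append(row)

    result = []
    start, end = frame
    for part in partitions.values():
        # 2. orderBy(): sort each partition independently
        ordered = sorted(part, key=lambda r: r[order_by])
        last_index = len(ordered) - 1
        for i, row in enumerate(ordered):
            # 3. rowsBetween(): take the physical rows at offsets start..end from i
            lo = 0 if start is None else max(0, i + start)
            hi = last_index if end is None else min(last_index, i + end)
            frame_values = [r[value] for r in ordered[lo:hi + 1]]
            # An empty frame gives None, the way Spark's sum/avg/min/max return null
            result.append({**row, out: agg(frame_values) if frame_values else None})
    return result


sample = [
    {"category": "Clothing", "day": 2, "sales": 20},
    {"category": "Clothing", "day": 1, "sales": 10},
    {"category": "Clothing", "day": 3, "sales": 30},
    {"category": "Electronics", "day": 1, "sales": 5},
    {"category": "Electronics", "day": 2, "sales": 7},
]

running = window_agg(sample, "category", "day", (None, 0), sum, "sales", "running_total")
print([(r["category"], r["day"], r["running_total"]) for r in running])
# [('Clothing', 1, 10), ('Clothing', 2, 30), ('Clothing', 3, 60),
#  ('Electronics', 1, 5), ('Electronics', 2, 12)]
```

Passing (None, 0) gives a running total, while (-1, 0) with an averaging function gives a two-row moving average; the partition and ordering steps stay the same and only the frame changes.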

The difference between rowsBetween() and rangeBetween() is subtle but important:

  • rowsBetween(): Physical row positions (e.g., “3 rows before current”)
  • rangeBetween(): Logical range based on the ORDER BY value (e.g., rangeBetween(-100, 0) covers “all rows whose value is between the current value minus 100 and the current value”)

# Sample time-series data with gaps: no rows for Jan 4, 7 or 8
from datetime import datetime

days = [1, 2, 3, 5, 6, 9, 10]
dates = [(datetime(2024, 1, d), 100 + d * 10) for d in days]
ts_df = spark.createDataFrame(dates, ["date", "sales"])

# ROWS: current row plus exactly 2 preceding rows, however far apart in time
rows_window = Window.orderBy("date").rowsBetween(-2, Window.currentRow)
ts_df.withColumn("sum_rows", sum("sales").over(rows_window)).show()
# sum_rows:  110, 230, 360, 400, 440, 500, 550

# RANGE: rows dated within the 2 days (172,800 seconds) before the current row
range_window = Window.orderBy(col("date").cast("long")) \
    .rangeBetween(-172800, Window.currentRow)
ts_df.withColumn("sum_range", sum("sales").over(range_window)).show()
# sum_range: 110, 230, 360, 280, 310, 190, 390
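Ties in the ordering column are a second place where the two frame types diverge: a RANGE frame includes every row whose ordering value equals the current row's (its peers), while a ROWS frame stops at the current physical row. The plain-Python sketch below models a running sum under each frame; the (order_value, amount) pairs are hypothetical and already sorted.

```python
def running_sum_rows(pairs):
    """Model of sum(amount) over rowsBetween(unboundedPreceding, currentRow)."""
    out, total = [], 0
    for _, amount in pairs:
        total += amount  # each physical row extends the frame by exactly one row
        out.append(total)
    return out


def running_sum_range(pairs):
    """Model of sum(amount) over rangeBetween(unboundedPreceding, currentRow).

    Every row whose order value is <= the current row's value is in the frame,
    so tied rows (peers) all receive the same total.
    """
    return [sum(a for v, a in pairs if v <= current) for current, _ in pairs]


# Sorted by order value; the two rows with value 2 are ties
pairs = [(1, 10), (2, 20), (2, 5), (3, 1)]
print(running_sum_rows(pairs))   # [10, 30, 35, 36]
print(running_sum_range(pairs))  # [10, 35, 35, 36]
```

Under ROWS, which tied row gets 30 and which gets 35 depends on the order Spark happens to produce for them. This also explains running totals that repeat on ties: when a window has orderBy() but no explicit frame, Spark uses a RANGE frame from unboundedPreceding to currentRow.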

Ranking Functions

Ranking functions assign positions to rows within partitions. They’re essential for top-N queries, percentile analysis, and competitive rankings.

from pyspark.sql.functions import row_number, rank, dense_rank, ntile

# Sample product sales data
sales_data = [
    ("Electronics", "Laptop", 1500),
    ("Electronics", "Phone", 800),
    ("Electronics", "Tablet", 600),
    ("Electronics", "Headphones", 150),
    ("Clothing", "Jacket", 200),
    ("Clothing", "Shoes", 120),
    ("Clothing", "Shirt", 50),
    ("Clothing", "Hat", 30)
]

sales_df = spark.createDataFrame(sales_data, ["category", "product", "revenue"])

window_spec = Window.partitionBy("category").orderBy(col("revenue").desc())

ranked_df = sales_df \
    .withColumn("row_num", row_number().over(window_spec)) \
    .withColumn("rank", rank().over(window_spec)) \
    .withColumn("dense_rank", dense_rank().over(window_spec)) \
    .withColumn("quartile", ntile(4).over(window_spec))

ranked_df.show()

Key differences between ranking functions:

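  • row_number(): Sequential numbering (1, 2, 3, 4…), always unique; ties are broken arbitrarily unless orderBy() includes a tiebreaker column
  • rank(): Gaps after ties (1, 2, 2, 4…)
  • dense_rank(): No gaps after ties (1, 2, 2, 3…)
  • ntile(n): Divides the ordered rows into n roughly equal buckets numbered 1 to n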
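The sample revenues above contain no ties, so row_num, rank and dense_rank come out identical there. The plain-Python sketch below (illustrative helpers, not Spark functions; the input is assumed to be already sorted in window order) shows how the three diverge on a tie, and how ntile() sizes its buckets: with k rows and n buckets, the first k mod n buckets get one extra row.

```python
def row_numbers(values):
    # Consecutive positions; in Spark the order among tied rows is arbitrary
    return list(range(1, len(values) + 1))


def ranks(values):
    # Tied values share the position of the first row in the tie, leaving gaps
    out = []
    for i, v in enumerate(values):
        out.append(out[-1] if i > 0 and v == values[i - 1] else i + 1)
    return out


def dense_ranks(values):
    # Tied values share a rank and the next distinct value gets rank + 1
    out = []
    for i, v in enumerate(values):
        if i == 0:
            out.append(1)
        else:
            out.append(out[-1] if v == values[i - 1] else out[-1] + 1)
    return out


def ntiles(count, n):
    # Bucket sizes differ by at most one; the first count % n buckets are larger
    base, extra = divmod(count, n)
    out = []
    for bucket in range(1, n + 1):
        size = base + (1 if bucket <= extra else 0)
        out.extend([bucket] * size)
    return out


revenue_desc = [1500, 800, 800, 150]  # hypothetical partition with a tie
print(row_numbers(revenue_desc))  # [1, 2, 3, 4]
print(ranks(revenue_desc))        # [1, 2, 2, 4]
print(dense_ranks(revenue_desc))  # [1, 2, 2, 3]
print(ntiles(10, 4))              # [1, 1, 1, 2, 2, 2, 3, 3, 4, 4]
```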

Practical use case - finding top 3 products per category:

top_products = ranked_df.filter(col("row_num") <= 3) \
    .select("category", "product", "revenue", "row_num")

top_products.show()
# (row order may vary)
# +-----------+-------+-------+-------+
# |   category|product|revenue|row_num|
# +-----------+-------+-------+-------+
# |   Clothing| Jacket|    200|      1|
# |   Clothing|  Shoes|    120|      2|
# |   Clothing|  Shirt|     50|      3|
# |Electronics| Laptop|   1500|      1|
# |Electronics|  Phone|    800|      2|
# |Electronics| Tablet|    600|      3|
# +-----------+-------+-------+-------+
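Which ranking column you filter on decides what happens when a tie falls on the top-N boundary. The plain-Python sketch below uses a hypothetical variant of the sample data in which Phone and Tablet tie at 800, and a hypothetical helper top_n that applies the filter to row_number, rank or dense_rank values.

```python
from itertools import groupby


def top_n(rows, n, method):
    """Keep rows whose ranking value is <= n within each group.

    rows are (group, name, revenue) tuples; method is "row_number", "rank"
    or "dense_rank", i.e. which ranking column the filter is applied to.
    """
    kept = []
    # Sort by group, then revenue descending (like orderBy(col("revenue").desc()))
    ordered = sorted(rows, key=lambda r: (r[0], -r[2]))
    for _, group_rows in groupby(ordered, key=lambda r: r[0]):
        prev_revenue, rank, dense = None, 0, 0
        for position, (_, name, revenue) in enumerate(group_rows, start=1):
            if revenue != prev_revenue:
                # New revenue value: rank jumps to the position, dense_rank adds one
                rank, dense = position, dense + 1
                prev_revenue = revenue
            value = {"row_number": position, "rank": rank, "dense_rank": dense}[method]
            if value <= n:
                kept.append(name)
    return kept


products = [
    ("Electronics", "Laptop", 1500),
    ("Electronics", "Phone", 800),
    ("Electronics", "Tablet", 800),
    ("Electronics", "Headphones", 150),
    ("Clothing", "Jacket", 200),
]
print(top_n(products, 2, "row_number"))  # ['Jacket', 'Laptop', 'Phone']
print(top_n(products, 2, "rank"))        # ['Jacket', 'Laptop', 'Phone', 'Tablet']
print(top_n(products, 3, "dense_rank"))  # ['Jacket', 'Laptop', 'Phone', 'Tablet', 'Headphones']
```

row_number keeps exactly n rows but, in Spark, picks one of the tied products arbitrarily (the sketch's stable sort simply keeps input order); rank keeps the whole tie and can return more than n rows; dense_rank counts distinct revenue values rather than rows.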

Aggregate Window Functions

Aggregate functions over windows enable running totals, moving averages, and cumulative statistics—critical for financial analysis and time-series data.

# Note: these imports shadow Python's built-in sum, min and max
from pyspark.sql.functions import sum, avg, min, max
from datetime import date

# Daily sales data (date objects give Spark a DateType column)
daily_sales = [
    (date(2024, 1, 1), 1000),
    (date(2024, 1, 2), 1200),
    (date(2024, 1, 3), 800),
    (date(2024, 1, 4), 1500),
    (date(2024, 1, 5), 1100),
    (date(2024, 1, 6), 1300),
    (date(2024, 1, 7), 900),
    (date(2024, 1, 8), 1400)
]

daily_df = spark.createDataFrame(daily_sales, ["date", "sales"])

# Running total - all rows from start to current
running_total_window = Window.orderBy("date") \
    .rowsBetween(Window.unboundedPreceding, 0)

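# 7-row moving average: the current row plus 6 preceding rows
# (equivalent to 7 days here only because there is exactly one row per day)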
moving_avg_window = Window.orderBy("date").rowsBetween(-6, 0)

result = daily_df \
    .withColumn("running_total", sum("sales").over(running_total_window)) \
    .withColumn("moving_avg_7d", avg("sales").over(moving_avg_window)) \
    .withColumn("max_to_date", max("sales").over(running_total_window))

result.show()
# (moving_avg_7d rounded to one decimal here; show() prints the full value)
# +----------+-----+-------------+-------------+-----------+
# |      date|sales|running_total|moving_avg_7d|max_to_date|
# +----------+-----+-------------+-------------+-----------+
# |2024-01-01| 1000|         1000|       1000.0|       1000|
# |2024-01-02| 1200|         2200|       1100.0|       1200|
# |2024-01-03|  800|         3000|       1000.0|       1200|
# |2024-01-04| 1500|         4500|       1125.0|       1500|
# |2024-01-05| 1100|         5600|       1120.0|       1500|
# |2024-01-06| 1300|         6900|       1150.0|       1500|
# |2024-01-07|  900|         7800|       1114.3|       1500|
# |2024-01-08| 1400|         9200|       1171.4|       1500|
# +----------+-----+-------------+-------------+-----------+

This pattern is invaluable for dashboards showing cumulative metrics alongside recent trends.
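To check the running_total, moving_avg_7d and max_to_date columns by hand, here is a plain-Python equivalent of the same frames (a verification aid, not how Spark evaluates them): itertools.accumulate models the unbounded-preceding frame, and a deque capped at seven values models rowsBetween(-6, 0).

```python
from collections import deque
from itertools import accumulate

sales = [1000, 1200, 800, 1500, 1100, 1300, 900, 1400]  # same values as daily_df

# rowsBetween(Window.unboundedPreceding, 0): sum of everything so far
running_total = list(accumulate(sales))

# rowsBetween(-6, 0): average of the current row and up to 6 preceding rows
moving_avg_7d = []
frame = deque(maxlen=7)  # the oldest value drops out automatically
for s in sales:
    frame.append(s)
    moving_avg_7d.append(sum(frame) / len(frame))

# max over the same unbounded-preceding frame
max_to_date = list(accumulate(sales, max))

print(running_total)  # [1000, 2200, 3000, 4500, 5600, 6900, 7800, 9200]
print([round(a, 1) for a in moving_avg_7d])
# [1000.0, 1100.0, 1000.0, 1125.0, 1120.0, 1150.0, 1114.3, 1171.4]
print(max_to_date)    # [1000, 1200, 1200, 1500, 1500, 1500, 1500, 1500]
```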

Analytical Functions

Analytical functions like lag(), lead(), first(), and last() let you access values from other rows within your partition without self-joins.

from pyspark.sql.functions import lag, lead, first, last

# Calculate day-over-day changes
# (lag/lead return null at the edges, so the first row's change is null)
window_ordered = Window.orderBy("date")

changes_df = daily_df \
    .withColumn("prev_day_sales", lag("sales", 1).over(window_ordered)) \
    .withColumn("next_day_sales", lead("sales", 1).over(window_ordered)) \
    .withColumn("day_over_day_change", col("sales") - col("prev_day_sales")) \
    .withColumn("change_pct",
                (col("sales") - col("prev_day_sales")) / col("prev_day_sales") * 100)

changes_df.show()
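lag() and lead() return null when the offset reaches past the edge of the window, so the first row's prev_day_sales, and every column derived from it, is null. The plain-Python sketch below models that behavior with hypothetical helpers, using None for null and an optional default like the one PySpark's lag() and lead() accept as a third argument.

```python
def lag(values, offset=1, default=None):
    # Value from `offset` rows earlier in the same ordered partition
    return [values[i - offset] if i - offset >= 0 else default
            for i in range(len(values))]


def lead(values, offset=1, default=None):
    # Value from `offset` rows later in the same ordered partition
    return [values[i + offset] if i + offset < len(values) else default
            for i in range(len(values))]


def pct_change(values):
    # Mirrors change_pct: null when there is no previous row
    return [None if prev is None else (cur - prev) / prev * 100
            for cur, prev in zip(values, lag(values))]


sales = [800, 1000, 750, 1500]  # hypothetical daily values
print(lag(sales))         # [None, 800, 1000, 750]
print(lead(sales))        # [1000, 750, 1500, None]
print(pct_change(sales))  # [None, 25.0, -25.0, 100.0]
```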

Comparing to first and last values in a partition:

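# For each category, compare to best and worst performing product.
# The frame must span the whole partition: with only orderBy(), the default
# frame ends at the current row (plus ties), so last() would return the current revenue.
category_window = Window.partitionBy("category") \
    .orderBy(col("revenue").desc()) \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

comparison_df = sales_df \
    .withColumn("top_product_revenue", first("revenue").over(category_window)) \
    .withColumn("bottom_product_revenue", last("revenue").over(category_window)) \
    .withColumn("gap_from_top",
                first("revenue").over(category_window) - col("revenue"))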

comparison_df.show()
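When a window has orderBy() and no explicit frame, Spark uses a frame from the start of the partition to the current row, so first() still finds the top value but last() only reaches the current row. A plain-Python model of one category's revenues, sorted descending and without ties (illustrative helpers, not Spark code):

```python
def last_default_frame(values):
    # Frame = start of partition .. current row, so last() is the current value
    return [values[: i + 1][-1] for i in range(len(values))]


def last_full_frame(values):
    # Frame = whole partition (unboundedPreceding .. unboundedFollowing)
    return [values[-1] for _ in values]


def first_default_frame(values):
    # first() is unaffected: the frame always starts at the partition's first row
    return [values[: i + 1][0] for i in range(len(values))]


revenue_desc = [200, 120, 50, 30]  # Clothing revenues, highest first
print(last_default_frame(revenue_desc))   # [200, 120, 50, 30]
print(last_full_frame(revenue_desc))      # [30, 30, 30, 30]
print(first_default_frame(revenue_desc))  # [200, 200, 200, 200]
```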

These functions eliminate the need for complex self-joins that would otherwise require multiple passes over the data.

Performance Considerations & Best Practices

Window functions can be expensive if not used carefully. Here are critical optimization strategies:

1. Partition wisely: A window without partitionBy() pulls every row into a single partition, and a low-cardinality or skewed key leaves a few tasks processing most of the data.

# Bad: no partitionBy(), so Spark moves all rows into a single partition
bad_window = Window.orderBy("date")

# Good: partition by a high-cardinality key (customer_id is a hypothetical column)
good_window = Window.partitionBy("customer_id").orderBy("date")

# Time-series option: partition by a time bucket, but only when the calculation
# should restart in each bucket (a running total here resets every month)
from pyspark.sql.functions import date_trunc
monthly_df = daily_df.withColumn("month", date_trunc("month", "date"))
monthly_window = Window.partitionBy("month").orderBy("date")

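2. Choose frames deliberately: The frame decides which rows feed each result, so the two specs below compute different things; pick the one your logic needs first. A bounded frame caps how many rows each output row aggregates, but Spark still buffers all rows of a window partition, so very large partitions stay costly whatever the frame.

# Whole-partition frame: every row sees all of its user's rows
whole_partition = Window.partitionBy("user_id").orderBy("timestamp") \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

# Bounded frame: each row sees itself and at most the 30 preceding rows
last_30_rows = Window.partitionBy("user_id").orderBy("timestamp") \
    .rowsBetween(-30, Window.currentRow)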

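3. Reuse window specifications: Defining a spec once keeps partitioning and ordering identical across columns, and window expressions that share the same partitionBy() and orderBy() can be computed after a single shuffle and sort, even when their frames differ.

# Define the window once (category_sales_df is a hypothetical DataFrame
# with category, date and sales columns)
window = Window.partitionBy("category").orderBy("date")

result = category_sales_df \
    .withColumn("running_total", sum("sales").over(window)) \
    .withColumn("row_num", row_number().over(window)) \
    .withColumn("moving_avg", avg("sales").over(window.rowsBetween(-6, Window.currentRow)))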

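4. Consider alternatives: For a global top-N rather than top-N per group, orderBy(...).limit(n) avoids a window entirely, and when you only need percentile cutoffs, approx_percentile is usually cheaper than ranking every row with ntile() or percent_rank().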

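5. Monitor partition sizes: Use df.rdd.glom().map(len).collect() to see how many rows each physical partition holds. To find hot keys before they overload a single task, count rows per window key with df.groupBy("customer_id").count().orderBy(col("count").desc()), where customer_id stands for your partitionBy() column.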
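The list returned by df.rdd.glom().map(len).collect() is easier to act on as one number. The sketch below defines a hypothetical helper skew_ratio (largest partition divided by the mean) and models why a hot window key causes skew: every row with the same partitionBy() value lands in the same partition. zlib.crc32 stands in for Spark's own Murmur3-based hash, so exact placements differ from Spark's, but the grouping property is the same.

```python
from zlib import crc32


def skew_ratio(partition_sizes):
    """Largest partition divided by the mean partition size (1.0 = perfectly even)."""
    total = sum(partition_sizes)
    if total == 0:
        return 1.0
    mean = total / len(partition_sizes)
    return max(partition_sizes) / mean


def partition_sizes_by_key(keys, num_partitions):
    """Model hash partitioning: rows with equal keys always share a partition."""
    sizes = [0] * num_partitions
    for key in keys:
        sizes[crc32(str(key).encode()) % num_partitions] += 1
    return sizes


# Hypothetical data: one "hot" customer owns 900 of 1,000 rows
keys = ["hot_customer"] * 900 + [f"customer_{i}" for i in range(100)]
sizes = partition_sizes_by_key(keys, num_partitions=8)
print(sizes)
print(skew_ratio(sizes))  # at least 900 / 125 = 7.2
```

On a live DataFrame the call would be skew_ratio(df.rdd.glom().map(len).collect()); a value near 1.0 means evenly sized partitions, while a value close to the partition count means one partition holds nearly everything.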

Window functions transform complex analytical queries from multi-step joins into elegant single-pass operations. Master partitioning, understand frame specifications, and monitor performance—your PySpark pipelines will be both more readable and more efficient.
