PySpark - Running Total with Window Function
Key Insights
- Window functions in PySpark enable running totals without collapsing rows like groupBy, maintaining the original dataset structure while adding cumulative calculations
- The rowsBetween(Window.unboundedPreceding, Window.currentRow) specification is critical for accurate running totals, defining the frame from the start of the partition to the current row
- Partitioning window functions by high-cardinality columns without proper salting can cause severe data skew and performance degradation in distributed environments
Understanding Running Totals in PySpark
Running totals, or cumulative sums, are essential calculations in data analysis that show the accumulation of values over an ordered sequence. Unlike simple aggregations that collapse data into summary statistics, running totals preserve every row while adding a cumulative perspective. You’ll encounter them in sales dashboards tracking revenue growth, inventory systems monitoring stock levels, and financial applications calculating account balances over time.
PySpark’s window functions provide the mechanism to compute running totals efficiently across distributed datasets. The key advantage over traditional SQL or pandas approaches is that window functions operate on partitions of data across your cluster, enabling calculations on datasets far larger than a single machine’s memory.
Window Functions vs. GroupBy Aggregations
Window functions fundamentally differ from groupBy operations in how they handle data. When you use groupBy, PySpark collapses rows into aggregated summaries—you lose the granular detail. Window functions, however, perform calculations across a “window” of rows while preserving every individual record.
Think of window functions as adding a new calculated column to your existing DataFrame, where each row’s value depends on a defined set of related rows. This “window” is controlled by three specifications: partitioning (which rows belong together), ordering (the sequence within each partition), and framing (which subset of the partition to include in each calculation).
Here’s the basic syntax structure:
from pyspark.sql import Window
from pyspark.sql.functions import sum, col
# Define a window specification
window_spec = Window.partitionBy("category").orderBy("date")
# Apply an aggregation over the window
df.withColumn("running_total", sum("amount").over(window_spec))
Creating a Sample Dataset
Let’s build a realistic sales dataset to demonstrate running totals. We’ll track daily sales across different products and regions:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType
from datetime import date
spark = SparkSession.builder.appName("RunningTotalDemo").getOrCreate()
# Define schema
schema = StructType([
StructField("date", DateType(), True),
StructField("product", StringType(), True),
StructField("region", StringType(), True),
StructField("amount", IntegerType(), True)
])
# Sample data
data = [
(date(2024, 1, 1), "Laptop", "North", 1200),
(date(2024, 1, 2), "Laptop", "North", 1500),
(date(2024, 1, 3), "Laptop", "North", 900),
(date(2024, 1, 1), "Laptop", "South", 1100),
(date(2024, 1, 2), "Laptop", "South", 1300),
(date(2024, 1, 1), "Phone", "North", 800),
(date(2024, 1, 2), "Phone", "North", 950),
(date(2024, 1, 3), "Phone", "North", 1050),
(date(2024, 1, 1), "Phone", "South", 700),
(date(2024, 1, 2), "Phone", "South", 850),
]
sales_df = spark.createDataFrame(data, schema)
sales_df.show()
Implementing a Basic Running Total
The simplest running total calculates cumulative values across the entire ordered dataset. The critical component is the frame specification using rowsBetween:
from pyspark.sql import Window
from pyspark.sql.functions import sum
# Define window ordered by date, no partitioning.
# Note: an unpartitioned window moves all rows to a single partition;
# fine for a demo, costly at scale.
window_spec = Window.orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow)
# Calculate running total
running_total_df = sales_df.withColumn(
"cumulative_sales",
sum("amount").over(window_spec)
)
running_total_df.orderBy("date", "product", "region").show()
The rowsBetween(Window.unboundedPreceding, Window.currentRow) specification is crucial. It defines the frame as “from the first row in the partition up to and including the current row.” Without this explicit frame definition, PySpark uses a default frame that may not produce the running total you expect.
Partitioned Running Totals
Real-world scenarios typically require running totals within specific groups. You might want cumulative sales per product, per region, or per product-region combination. This is where partitionBy becomes essential:
# Running total per product
product_window = Window.partitionBy("product").orderBy("date").rowsBetween(
Window.unboundedPreceding, Window.currentRow
)
product_running_total = sales_df.withColumn(
"product_cumulative",
sum("amount").over(product_window)
)
product_running_total.orderBy("product", "date").show()
# Running total per product AND region
product_region_window = Window.partitionBy("product", "region").orderBy("date").rowsBetween(
Window.unboundedPreceding, Window.currentRow
)
detailed_running_total = sales_df.withColumn(
"detailed_cumulative",
sum("amount").over(product_region_window)
)
detailed_running_total.orderBy("product", "region", "date").show()
Notice how partitioning resets the running total for each group. When you partition by product, laptops and phones maintain separate cumulative calculations. Adding region to the partition further segments the data, giving you independent running totals for each product-region combination.
Advanced Window Frame Specifications
Window frames offer flexibility beyond simple running totals. Understanding rowsBetween versus rangeBetween unlocks more sophisticated calculations:
from pyspark.sql.functions import avg

# Standard running total using rowsBetween
rows_window = Window.orderBy("date").rowsBetween(
Window.unboundedPreceding, Window.currentRow
)
# Range-based window (value-based, not row-based)
range_window = Window.orderBy("date").rangeBetween(
Window.unboundedPreceding, Window.currentRow
)
# Moving average over last 3 rows (including current)
moving_avg_window = Window.orderBy("date").rowsBetween(-2, Window.currentRow)
advanced_df = sales_df.withColumn("running_total", sum("amount").over(rows_window)) \
.withColumn("moving_avg", avg("amount").over(moving_avg_window))
advanced_df.orderBy("date").show()
The rowsBetween method operates on physical row positions, while rangeBetween works with logical value ranges. For running totals ordered by dates, rangeBetween handles ties (multiple rows with the same date) differently—it includes all rows with values less than or equal to the current row’s ordering value.
For moving averages versus running totals, the frame specification changes. A running total uses unboundedPreceding to the current row, while a 3-day moving average uses -2 (two rows back) to the current row.
Performance Considerations and Best Practices
Window functions trigger data shuffling across your Spark cluster, which can become a bottleneck with poor partitioning strategies. Here are critical optimization techniques:
Choose partition columns wisely. Partitioning by high-cardinality columns (like customer_id with millions of values) spreads work across many small windows but increases shuffle overhead. Partitioning by low-cardinality columns (like region with 4 values) limits parallelism, and either choice can produce data skew if some keys hold far more rows than others.
# Check execution plan
window_spec = Window.partitionBy("product").orderBy("date").rowsBetween(
Window.unboundedPreceding, Window.currentRow
)
result_df = sales_df.withColumn("running_total", sum("amount").over(window_spec))
# Examine physical plan to identify shuffle operations
result_df.explain()
Cache strategically. If you’re computing multiple window functions over the same DataFrame, cache the source data to avoid re-reading:
sales_df.cache()
# Multiple window calculations benefit from caching
window1 = Window.partitionBy("product").orderBy("date").rowsBetween(
    Window.unboundedPreceding, Window.currentRow
)
window2 = Window.partitionBy("region").orderBy("date").rowsBetween(
    Window.unboundedPreceding, Window.currentRow
)
df_with_totals = sales_df.withColumn("running_total", sum("amount").over(window1)) \
    .withColumn("running_avg", avg("amount").over(window2))
Avoid unnecessary ordering. If your data is already sorted by the order column within partitions, Spark can skip the sort phase. When possible, write data in the order you’ll query it.
Monitor data skew. Use Spark UI to identify skewed partitions where one executor processes significantly more data than others. If you see skew, consider salting techniques or repartitioning before applying window functions.
Running totals with PySpark window functions are powerful tools for time-series analysis and cumulative metrics. Master the frame specifications, understand partitioning implications, and monitor performance to handle production-scale datasets effectively. The pattern of partitionBy for grouping, orderBy for sequence, and rowsBetween for framing gives you precise control over cumulative calculations while maintaining the distributed processing advantages of Spark.