PySpark - Lead and Lag Functions
Key Insights
- Lead and lag functions eliminate expensive self-joins when comparing sequential rows, offering 3-10x performance improvements for time-series analysis in distributed environments.
- Proper window specification with `partitionBy()` and `orderBy()` is non-negotiable—without explicit ordering, lead/lag produce non-deterministic results that will silently corrupt your analysis.
- Combining lead and lag in a single pass enables complex sequential analytics like calculating durations, detecting anomalies, and identifying trend reversals without multiple DataFrame scans.
Introduction to Window Functions in PySpark
Window functions operate on a subset of rows related to the current row, enabling calculations across row boundaries without collapsing the dataset like `groupBy()` does. Lead and lag functions are the workhorses of sequential data analysis, allowing you to access values from subsequent or previous rows within a defined window.
The traditional approach to comparing row values requires self-joins—joining a DataFrame to itself with offset conditions. This is expensive in distributed computing environments. A self-join to compare today’s sales with yesterday’s requires shuffling potentially billions of rows across the cluster twice. Lead and lag functions perform these comparisons in a single pass over ordered partitions, dramatically reducing computational overhead.
These functions are indispensable for time-series analysis, calculating growth rates, detecting sequence gaps, analyzing user journeys, and any scenario where row context matters.
Understanding the Lag Function
The lag() function retrieves values from a previous row within the same partition. It accepts three parameters: the column to retrieve, the offset (number of rows back), and a default value when no previous row exists.
```python
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, col

spark = SparkSession.builder.appName("LagExample").getOrCreate()

# Stock price dataset
stock_data = [
    ("AAPL", "2024-01-01", 150.0),
    ("AAPL", "2024-01-02", 152.5),
    ("AAPL", "2024-01-03", 148.0),
    ("AAPL", "2024-01-04", 151.0),
    ("GOOGL", "2024-01-01", 2800.0),
    ("GOOGL", "2024-01-02", 2825.0),
]

df = spark.createDataFrame(stock_data, ["ticker", "date", "price"])

# Define window specification
window_spec = Window.partitionBy("ticker").orderBy("date")

# Calculate day-over-day price change
df_with_lag = df.withColumn("prev_price", lag("price", 1).over(window_spec)) \
    .withColumn("price_change", col("price") - col("prev_price")) \
    .withColumn("pct_change",
                (col("price") - col("prev_price")) / col("prev_price") * 100)

df_with_lag.show()
```
Output shows the previous day’s price in a new column, with null for the first row in each partition:
```
+------+----------+------+----------+------------+----------+
|ticker|      date| price|prev_price|price_change|pct_change|
+------+----------+------+----------+------------+----------+
|  AAPL|2024-01-01| 150.0|      null|        null|      null|
|  AAPL|2024-01-02| 152.5|     150.0|         2.5|  1.666667|
|  AAPL|2024-01-03| 148.0|     152.5|        -4.5| -2.950820|
|  AAPL|2024-01-04| 151.0|     148.0|         3.0|  2.027027|
| GOOGL|2024-01-01|2800.0|      null|        null|      null|
| GOOGL|2024-01-02|2825.0|    2800.0|        25.0|  0.892857|
+------+----------+------+----------+------------+----------+
```
The offset parameter defaults to 1 but can be any positive integer. Use `lag("price", 7)` to compare with the same day last week. The default value parameter handles edge cases—use `lag("price", 1, 0)` to replace nulls with zero instead of leaving them null.
Understanding the Lead Function
The lead() function mirrors lag() but accesses subsequent rows instead of previous ones. It’s essential for forward-looking analysis and forecasting comparisons.
```python
from pyspark.sql.functions import lead, col

# Sales target dataset
sales_data = [
    ("Q1", "2024-01", 100000),
    ("Q1", "2024-02", 105000),
    ("Q1", "2024-03", 110000),
    ("Q2", "2024-04", 115000),
    ("Q2", "2024-05", 120000),
]

df_sales = spark.createDataFrame(sales_data, ["quarter", "month", "target"])

window_spec = Window.partitionBy("quarter").orderBy("month")

df_with_lead = df_sales.withColumn("next_target", lead("target", 1).over(window_spec)) \
    .withColumn("target_increase", col("next_target") - col("target"))

df_with_lead.show()
```
Lead is particularly useful for calculating durations (time until next event), forecasting variance (actual vs. next period’s target), and identifying the last occurrence in a sequence (where lead() returns null).
Window Specification and Partitioning
The window specification defines the scope and ordering for lead/lag operations. `partitionBy()` creates independent groups, while `orderBy()` establishes the sequence within each partition.
```python
from pyspark.sql.functions import lead, lag, col

# Employee salary history
employee_data = [
    ("Engineering", "Alice", "2022-01-15", 80000),
    ("Engineering", "Alice", "2023-01-15", 85000),
    ("Engineering", "Bob", "2022-03-01", 75000),
    ("Engineering", "Bob", "2023-03-01", 78000),
    ("Sales", "Charlie", "2022-02-01", 70000),
    ("Sales", "Charlie", "2023-02-01", 75000),
]

df_emp = spark.createDataFrame(
    employee_data, ["department", "name", "review_date", "salary"])

# Window partitioned by department AND employee, ordered by review date
window_spec = Window.partitionBy("department", "name").orderBy("review_date")

df_salary_changes = df_emp.withColumn("previous_salary",
        lag("salary", 1, 0).over(window_spec)) \
    .withColumn("next_salary", lead("salary", 1).over(window_spec)) \
    .withColumn("raise_amount", col("salary") - col("previous_salary"))

df_salary_changes.show(truncate=False)
```
Partitioning resets the window boundaries. In this example, each employee’s salary history is independent—Alice’s lag doesn’t access Bob’s data even though they’re in the same department. Multi-column partitioning is common for hierarchical data: partition by region and store, or user and session.
The `orderBy()` clause is mandatory for deterministic results. Without it, PySpark may return different row orders across executions, making lead/lag values unpredictable.
Practical Applications and Advanced Patterns
Combining lead and lag enables sophisticated sequential analysis. Here’s a customer journey analysis calculating session durations and identifying drop-off points:
```python
from pyspark.sql.functions import lead, lag, unix_timestamp, when, col

# User activity log
activity_data = [
    ("user_1", "2024-01-01 10:00:00", "login"),
    ("user_1", "2024-01-01 10:15:00", "page_view"),
    ("user_1", "2024-01-01 10:30:00", "logout"),
    ("user_1", "2024-01-01 14:00:00", "login"),
    ("user_1", "2024-01-01 14:45:00", "logout"),
    ("user_2", "2024-01-01 09:00:00", "login"),
    ("user_2", "2024-01-01 09:20:00", "page_view"),
]

df_activity = spark.createDataFrame(
    activity_data, ["user_id", "timestamp", "event"])

window_spec = Window.partitionBy("user_id").orderBy("timestamp")

df_sessions = df_activity.withColumn("next_event", lead("event", 1).over(window_spec)) \
    .withColumn("next_timestamp", lead("timestamp", 1).over(window_spec)) \
    .withColumn("prev_event", lag("event", 1).over(window_spec)) \
    .withColumn("session_duration_min",
        (unix_timestamp("next_timestamp") - unix_timestamp("timestamp")) / 60) \
    .withColumn("is_dropout",
        when((col("event") != "logout") & col("next_event").isNull(), True)
        .otherwise(False))

df_sessions.show(truncate=False)
```
This pattern identifies incomplete sessions (users who didn’t log out) and calculates time between events. You can extend this to detect anomalies (gaps exceeding thresholds), calculate retention rates (comparing user activity across periods), or fill missing values using forward/backward fill strategies.
Performance Considerations and Best Practices
Window functions trigger a shuffle operation to group rows by partition. Optimize performance by:
1. Right-size your partitions: Too few partitions underutilize the cluster; too many create overhead. Aim for partitions of 100MB-200MB.
2. Minimize partition skew: If one partition has 90% of the data, that single executor becomes a bottleneck. Add salt to skewed keys or use multiple partition columns.
3. Reuse window specifications: Define the window once and apply it to multiple columns to avoid redundant shuffles.
```python
# Inefficient - multiple shuffles
df.withColumn("lag1", lag("value", 1).over(Window.partitionBy("id").orderBy("date"))) \
  .withColumn("lag2", lag("value", 2).over(Window.partitionBy("id").orderBy("date")))

# Efficient - single shuffle
window_spec = Window.partitionBy("id").orderBy("date")
df.withColumn("lag1", lag("value", 1).over(window_spec)) \
  .withColumn("lag2", lag("value", 2).over(window_spec))
```
4. Use default values strategically: Replace nulls with meaningful defaults (lag("price", 1, 0)) to avoid null propagation in downstream calculations.
5. Limit partition size with filters: Apply where() clauses before window functions to reduce data volume.
Common Pitfalls and Troubleshooting
The most frequent error is omitting `orderBy()` from the window specification. For lead/lag specifically, Spark refuses to run and raises an AnalysisException, because these functions require an ordered window:

```python
# WRONG - no ordering: Spark raises AnalysisException for lag()
window_spec_bad = Window.partitionBy("category")
df.withColumn("prev_value", lag("value", 1).over(window_spec_bad))

# CORRECT - explicit ordering
window_spec_good = Window.partitionBy("category").orderBy("timestamp")
df.withColumn("prev_value", lag("value", 1).over(window_spec_good))
```

A subtler variant is ordering by a non-unique column: rows that tie on the sort key can come back in different orders across runs, making lead/lag values non-deterministic. Order by a key that uniquely identifies each row within the partition, adding a tiebreaker column if necessary.
Another common issue is misunderstanding partition boundaries. Lead and lag never cross partition boundaries—the last row in a partition will have a null lead value, even if another partition follows.
Offset miscalculations occur when using dynamic offsets. If you need variable offsets based on data conditions, you’ll need conditional logic or multiple lag columns with coalesce():
```python
from pyspark.sql.functions import coalesce

# Get the most recent non-null value from the last 3 rows
df.withColumn("lag1", lag("value", 1).over(window_spec)) \
  .withColumn("lag2", lag("value", 2).over(window_spec)) \
  .withColumn("lag3", lag("value", 3).over(window_spec)) \
  .withColumn("last_valid", coalesce("lag1", "lag2", "lag3"))
```
Master lead and lag functions, and you’ll handle sequential analysis efficiently without the performance penalty of self-joins. The key is proper window specification, understanding partition boundaries, and leveraging default values to handle edge cases cleanly.