PySpark - RDD reduceByKey with Examples
Key Insights
- reduceByKey() performs parallel aggregation on PairRDDs by merging values with the same key using an associative and commutative reduce function, making it significantly more efficient than groupByKey() for large datasets
- The transformation combines values locally on each partition before shuffling data across the network, reducing data transfer by up to 90% compared to alternatives that shuffle all values
- Choosing between reduceByKey(), aggregateByKey(), and combineByKey() depends on whether your input value type matches your output type and whether you need different logic for combining within and across partitions
Understanding reduceByKey Fundamentals
reduceByKey() is a transformation operation that works exclusively on PairRDDs (key-value pairs). It merges values associated with each key using a specified reduce function. The function you provide must be both associative and commutative because PySpark applies it in parallel across partitions.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("ReduceByKeyExample").getOrCreate()
sc = spark.sparkContext
# Create a simple PairRDD
data = [("apple", 1), ("banana", 2), ("apple", 3), ("banana", 4), ("apple", 5)]
rdd = sc.parallelize(data)
# Sum values for each key
result = rdd.reduceByKey(lambda x, y: x + y)
print(result.collect())
# Output: [('banana', 6), ('apple', 9)]
The lambda function lambda x, y: x + y takes two values and returns their sum. PySpark applies this function repeatedly until only one value remains per key.
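To build intuition for this pairwise folding, you can reproduce it for a single key with Python's built-in functools.reduce. This is a driver-side analogy only, not how Spark actually schedules the work across partitions:
from functools import reduce
# For the key "apple" above, the values 1, 3, 5 collapse pairwise:
# (1 + 3) -> 4, then (4 + 5) -> 9, matching the reduceByKey result
print(reduce(lambda x, y: x + y, [1, 3, 5]))
# Output: 9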
Performance Comparison with groupByKey
The primary advantage of reduceByKey() is its optimization strategy. Unlike groupByKey(), which shuffles all values across the network before aggregation, reduceByKey() performs local aggregation first.
# Inefficient approach with groupByKey
data = [("user1", 100), ("user2", 200), ("user1", 150),
("user2", 250), ("user1", 300)]
rdd = sc.parallelize(data, 4)
# groupByKey shuffles all values
grouped = rdd.groupByKey().mapValues(sum)
print(grouped.collect())
# Output: [('user2', 450), ('user1', 550)]
# Efficient approach with reduceByKey
reduced = rdd.reduceByKey(lambda x, y: x + y)
print(reduced.collect())
# Output: [('user2', 450), ('user1', 550)]
For a dataset with 1 million records spread over 100 partitions, groupByKey() shuffles all 1 million values, while reduceByKey() shuffles at most one pre-aggregated value per key per partition, which with only a handful of distinct keys amounts to just a few hundred values.
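The map-side combine can be illustrated loosely by pre-aggregating each partition by hand with mapPartitions before the final reduce. This is only a sketch of the idea, not how Spark implements it internally, and it reuses the rdd created above:
# Manually pre-aggregate within each partition to mimic the map-side combine
def combine_partition(records):
    totals = {}
    for key, value in records:
        totals[key] = totals.get(key, 0) + value
    return iter(totals.items())
locally_combined = rdd.mapPartitions(combine_partition)
# Each partition now emits at most one record per key it holds, so far fewer
# records need to cross the network before the final merge
print(locally_combined.reduceByKey(lambda x, y: x + y).collect())
# Output: [('user2', 450), ('user1', 550)]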
Word Count Implementation
The canonical MapReduce example demonstrates reduceByKey() in action:
text_data = [
"spark is fast",
"spark is powerful",
"python works with spark"
]
text_rdd = sc.parallelize(text_data)
# Split, map to pairs, and count
word_counts = (text_rdd
.flatMap(lambda line: line.split())
.map(lambda word: (word, 1))
.reduceByKey(lambda x, y: x + y))
print(word_counts.collect())
# Output: [('is', 2), ('spark', 3), ('fast', 1), ('powerful', 1),
# ('python', 1), ('works', 1), ('with', 1)]
This pattern—map to (key, 1) pairs then reduce—is fundamental to distributed counting operations.
Aggregating Complex Data Types
reduceByKey() works with any data type as long as your reduce function handles it correctly:
# Aggregating lists
purchases = [
("customer1", ["item1", "item2"]),
("customer2", ["item3"]),
("customer1", ["item4", "item5"]),
("customer2", ["item6"])
]
rdd = sc.parallelize(purchases)
combined = rdd.reduceByKey(lambda x, y: x + y)
print(combined.collect())
# Output: [('customer2', ['item3', 'item6']),
# ('customer1', ['item1', 'item2', 'item4', 'item5'])]
# Aggregating tuples (min, max, count)
metrics = [
("server1", (10, 10, 1)),
("server2", (20, 20, 1)),
("server1", (5, 15, 1)),
("server2", (25, 30, 1))
]
rdd = sc.parallelize(metrics)
aggregated = rdd.reduceByKey(
lambda x, y: (min(x[0], y[0]), max(x[1], y[1]), x[2] + y[2])
)
print(aggregated.collect())
# Output: [('server2', (20, 30, 2)), ('server1', (5, 15, 2))]
Financial Data Aggregation Example
Real-world scenarios often require multiple aggregations simultaneously:
# Transaction data: (account_id, (amount, transaction_count))
transactions = [
("ACC001", (100.50, 1)),
("ACC002", (250.75, 1)),
("ACC001", (75.25, 1)),
("ACC002", (150.00, 1)),
("ACC001", (200.00, 1))
]
rdd = sc.parallelize(transactions)
# Calculate total amount and transaction count per account
account_summary = rdd.reduceByKey(
lambda x, y: (x[0] + y[0], x[1] + y[1])
)
# Calculate average transaction amount
account_averages = account_summary.mapValues(
lambda v: {"total": v[0], "count": v[1], "average": v[0] / v[1]}
)
print(account_averages.collect())
# Output: [('ACC002', {'total': 400.75, 'count': 2, 'average': 200.375}),
# ('ACC001', {'total': 375.75, 'count': 3, 'average': 125.25})]
Custom Reduce Functions
For complex business logic, define named functions instead of lambdas:
def merge_user_sessions(session1, session2):
    """Merge two user session dictionaries."""
    return {
        "page_views": session1["page_views"] + session2["page_views"],
        "duration": session1["duration"] + session2["duration"],
        "last_active": max(session1["last_active"], session2["last_active"])
    }
sessions = [
("user1", {"page_views": 5, "duration": 120, "last_active": 1000}),
("user2", {"page_views": 3, "duration": 90, "last_active": 2000}),
("user1", {"page_views": 8, "duration": 200, "last_active": 3000}),
("user2", {"page_views": 2, "duration": 45, "last_active": 2500})
]
rdd = sc.parallelize(sessions)
merged_sessions = rdd.reduceByKey(merge_user_sessions)
print(merged_sessions.collect())
# Output: [('user2', {'page_views': 5, 'duration': 135, 'last_active': 2500}),
# ('user1', {'page_views': 13, 'duration': 320, 'last_active': 3000})]
Partitioning Considerations
Control the number of output partitions to optimize downstream operations:
data = [("key" + str(i % 100), i) for i in range(10000)]
rdd = sc.parallelize(data, 8)
# Default partitioning (hash partitioning)
result1 = rdd.reduceByKey(lambda x, y: x + y)
print(f"Default partitions: {result1.getNumPartitions()}")
# Output: Default partitions: 8
# Specify number of partitions
result2 = rdd.reduceByKey(lambda x, y: x + y, numPartitions=4)
print(f"Custom partitions: {result2.getNumPartitions()}")
# Output: Custom partitions: 4
# Using a custom partition function
# The Python RDD API has no Partitioner class to subclass; instead, pass a
# plain function that maps a key to a partition index (it must be
# deterministic across executor processes)
def custom_partitioner(key):
    # Keys look like "key0" ... "key99", so route by the numeric suffix
    return int(key[3:]) % 4
result3 = rdd.reduceByKey(lambda x, y: x + y, numPartitions=4,
                          partitionFunc=custom_partitioner)
print(f"Partitions with custom function: {result3.getNumPartitions()}")
# Output: Partitions with custom function: 4
When to Use reduceByKey vs Alternatives
Use reduceByKey() when your input and output value types are identical. For different types or more complex aggregation logic, consider alternatives:
# reduceByKey: input type = output type
sales = [("product1", 100), ("product2", 200), ("product1", 150)]
rdd = sc.parallelize(sales)
total_sales = rdd.reduceByKey(lambda x, y: x + y)
# aggregateByKey: when you need different types or zero value
# Calculate average (requires count and sum)
avg_sales = rdd.aggregateByKey(
(0, 0), # zero value: (sum, count)
lambda acc, value: (acc[0] + value, acc[1] + 1), # seq function
lambda acc1, acc2: (acc1[0] + acc2[0], acc1[1] + acc2[1]) # comb function
).mapValues(lambda v: v[0] / v[1] if v[1] > 0 else 0)
print(avg_sales.collect())
# Output: [('product2', 200.0), ('product1', 125.0)]
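combineByKey() generalizes this further: you also control how the first value seen for a key becomes the accumulator. A minimal sketch computing the same averages on the sales rdd defined above:
# combineByKey: createCombiner builds the accumulator from the first value,
# mergeValue folds further values within a partition, and mergeCombiners
# merges accumulators across partitions
avg_sales_cbk = rdd.combineByKey(
    lambda value: (value, 1),                         # createCombiner
    lambda acc, value: (acc[0] + value, acc[1] + 1),  # mergeValue
    lambda a, b: (a[0] + b[0], a[1] + b[1])           # mergeCombiners
).mapValues(lambda v: v[0] / v[1])
print(avg_sales_cbk.collect())
# Output: [('product2', 200.0), ('product1', 125.0)]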
The reduce function must be associative ((a + b) + c = a + (b + c)) and commutative (a + b = b + a) so that results stay consistent no matter how Spark orders and groups the computation across partitions.
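A quick way to see why this matters: subtraction is neither associative nor commutative, so reducing with it can yield different answers depending on how the records are partitioned. The two results below are not guaranteed to agree:
# Never use a non-associative or non-commutative function with reduceByKey
values = [("k", 10), ("k", 3), ("k", 2)]
print(sc.parallelize(values, 1).reduceByKey(lambda x, y: x - y).collect())
print(sc.parallelize(values, 3).reduceByKey(lambda x, y: x - y).collect())
# With one partition the fold is (10 - 3) - 2 = 5; with three partitions the
# merge order depends on the shuffle, so the result may differ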