PySpark RDD Tutorial - Complete Guide with Examples

Key Insights

  • RDDs are PySpark’s foundational data structure offering fault tolerance and distributed processing, but DataFrames should be your default choice for structured data due to Catalyst optimization—use RDDs only when you need fine-grained control or are working with unstructured data
  • Understanding the difference between transformations (lazy) and actions (eager) is critical for performance—chain transformations together and minimize actions to avoid unnecessary computation and data movement
  • Proper partitioning and caching strategies can dramatically improve RDD performance, but blindly caching everything wastes memory—cache only RDDs you’ll reuse multiple times in your pipeline

Introduction to RDDs (Resilient Distributed Datasets)

RDDs are the fundamental data structure in Apache Spark. They represent an immutable, distributed collection of objects that can be processed in parallel across a cluster. While DataFrames and Datasets have largely replaced RDDs for most use cases, understanding RDDs remains essential for advanced Spark programming and troubleshooting.

RDDs have three key characteristics that make them powerful:

Immutability: Once created, RDDs cannot be modified. Every transformation creates a new RDD, which enables safe parallel processing and easy fault recovery.

Partitioning: Data is automatically divided into partitions that can be processed on different nodes in your cluster. This enables true distributed computing.

Fault Tolerance: RDDs remember their lineage—the sequence of transformations used to build them. If a partition is lost, Spark can reconstruct it by replaying those transformations.

When should you use RDDs instead of DataFrames? Use RDDs when you need fine-grained control over data processing, when working with unstructured data like media files, or when implementing custom partitioning logic. For structured data with SQL-like operations, DataFrames are faster and more convenient.

from pyspark import SparkContext, SparkConf

# Initialize Spark
conf = SparkConf().setAppName("RDD Tutorial").setMaster("local[*]")
sc = SparkContext(conf=conf)

# Simple RDD creation from a list
numbers = sc.parallelize([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
print(f"Number of partitions: {numbers.getNumPartitions()}")
print(f"First element: {numbers.first()}")

Creating RDDs

PySpark provides multiple ways to create RDDs depending on your data source.

The parallelize() method converts Python collections into RDDs. This is useful for testing and small datasets:

# From a list
data = [("Alice", 34), ("Bob", 45), ("Charlie", 28)]
people_rdd = sc.parallelize(data, numSlices=4)  # 4 partitions

# From a range
range_rdd = sc.parallelize(range(1000), numSlices=8)

For production workloads, you’ll typically load data from external sources:

# Read text file (each line becomes an element)
log_rdd = sc.textFile("hdfs://path/to/logs/*.log")

# Read multiple text files, preserving filenames
files_rdd = sc.wholeTextFiles("s3://bucket/data/")
# Returns (filename, content) tuples

# Read with specific number of partitions
csv_rdd = sc.textFile("data/sales.csv", minPartitions=16)

# Skip header line
header = csv_rdd.first()
data_rdd = csv_rdd.filter(lambda line: line != header)

You can also create RDDs from existing RDDs through transformations, which we’ll explore next.

RDD Transformations

Transformations are lazy operations that define a new RDD based on an existing one. They’re not executed until an action is called, allowing Spark to optimize the execution plan.

Narrow transformations operate on single partitions independently:

# map: Apply function to each element
numbers = sc.parallelize([1, 2, 3, 4, 5])
squared = numbers.map(lambda x: x ** 2)
# Result: [1, 4, 9, 16, 25]

# filter: Keep elements matching a condition
evens = numbers.filter(lambda x: x % 2 == 0)
# Result: [2, 4]

# flatMap: Map then flatten results
words = sc.parallelize(["hello world", "spark rdd"])
all_words = words.flatMap(lambda line: line.split())
# Result: ["hello", "world", "spark", "rdd"]

Wide transformations require data shuffling across partitions, making them more expensive:

# Classic word count example
text = sc.textFile("data/shakespeare.txt")
word_counts = (text
    .flatMap(lambda line: line.lower().split())
    .map(lambda word: (word, 1))
    .reduceByKey(lambda a, b: a + b)
    .sortBy(lambda pair: pair[1], ascending=False))

# Take top 10 most common words
top_words = word_counts.take(10)

Here’s a practical data transformation pipeline:

# Process and clean sales data
sales_raw = sc.textFile("data/sales.csv")

sales_clean = (sales_raw
    .filter(lambda line: not line.startswith("date"))  # Skip header
    .map(lambda line: line.split(","))
    .filter(lambda fields: len(fields) == 4)  # Validate structure
    .map(lambda fields: {
        "date": fields[0],
        "product": fields[1],
        "quantity": int(fields[2]),
        "revenue": float(fields[3])
    }))

RDD Actions

Actions trigger the actual computation and return results to the driver or write to storage. Use them sparingly: each action re-executes the entire transformation chain unless the intermediate RDDs are cached.

numbers = sc.parallelize(range(1, 101))

# Basic actions
print(numbers.count())           # 100
print(numbers.first())           # 1
print(numbers.take(5))          # [1, 2, 3, 4, 5]
print(numbers.top(3))           # [100, 99, 98]

# Aggregate actions
total = numbers.reduce(lambda a, b: a + b)  # Sum all numbers
print(f"Sum: {total}")  # 5050

# Statistics
stats = numbers.stats()
print(f"Mean: {stats.mean()}, StdDev: {stats.stdev()}")

# Collect all data (dangerous with large datasets!)
all_data = numbers.collect()  # Returns Python list

Save results to persistent storage:

# Save as text file
word_counts.saveAsTextFile("output/word_counts")

# Save with custom formatting
word_counts.map(lambda pair: f"{pair[0]}\t{pair[1]}").saveAsTextFile("output/formatted")

Working with Key-Value Pairs

Pair RDDs (RDDs of key-value tuples) unlock powerful grouping and aggregation operations.

# Create pair RDD from log data
logs = sc.textFile("data/access.log")

# Extract (IP address, bytes sent) pairs
def extract_ip_bytes(line):
    fields = line.split()  # split once, reuse the fields
    return (fields[0], int(fields[-1]))  # (IP address, bytes sent)

ip_bytes = logs.map(extract_ip_bytes)

# Aggregate by key
total_bytes_by_ip = ip_bytes.reduceByKey(lambda a, b: a + b)

# Get average bytes per IP
count_by_ip = ip_bytes.mapValues(lambda x: (x, 1))
sums_and_counts = count_by_ip.reduceByKey(
    lambda a, b: (a[0] + b[0], a[1] + b[1])
)
averages = sums_and_counts.mapValues(lambda x: x[0] / x[1])

Advanced key-value operations:

# Sample data: user purchases
purchases = sc.parallelize([
    ("user1", "laptop"),
    ("user2", "mouse"),
    ("user1", "keyboard"),
    ("user3", "monitor"),
    ("user2", "laptop")
])

# Group all values by key
grouped = purchases.groupByKey()
user_items = grouped.mapValues(list)
# Result: [("user1", ["laptop", "keyboard"]), ...]

# Join operations
user_ages = sc.parallelize([("user1", 25), ("user2", 30), ("user3", 35)])
user_data = purchases.join(user_ages)
# Result: [("user1", ("laptop", 25)), ("user1", ("keyboard", 25)), ...]

RDD Partitioning and Performance

Partitioning directly impacts parallelism and performance. Understanding and controlling it is crucial for production workloads.

# Check current partitioning
data = sc.textFile("large_file.txt")
print(f"Partitions: {data.getNumPartitions()}")

# Increase partitions for more parallelism
data_repartitioned = data.repartition(100)

# Decrease partitions (coalesce avoids the full shuffle that repartition does)
data_coalesced = data.coalesce(10)

# Custom partitioner for skewed data
def custom_partitioner(key):
    # Distribute based on first letter
    return ord(key[0].lower()) % 10

word_pairs = words.map(lambda w: (w, 1))
partitioned = word_pairs.partitionBy(10, custom_partitioner)

Caching strategies dramatically improve performance for iterative algorithms:

from pyspark import StorageLevel

# Cache in memory (default)
frequent_data = sc.textFile("data.txt").cache()

# Persist with specific storage level
# MEMORY_ONLY: Fast but can lose data if not enough RAM
# MEMORY_AND_DISK: Spill to disk when memory is full
# DISK_ONLY: Slower but reliable
expensive_rdd = data.map(complex_transformation).persist(StorageLevel.MEMORY_AND_DISK)

# Use the cached RDD multiple times
result1 = expensive_rdd.filter(condition1).count()
result2 = expensive_rdd.filter(condition2).count()

# Unpersist when done
expensive_rdd.unpersist()

Real-World Example: Building a Data Pipeline

Let’s build a complete ETL pipeline that processes e-commerce transaction logs:

from datetime import datetime

def parse_log_line(line):
    """Parse log line with error handling"""
    try:
        parts = line.split("|")
        return {
            "timestamp": datetime.strptime(parts[0], "%Y-%m-%d %H:%M:%S"),
            "user_id": parts[1],
            "action": parts[2],
            "product_id": parts[3],
            "amount": float(parts[4]) if len(parts) > 4 else 0.0
        }
    except (ValueError, IndexError):
        return None  # Skip malformed records rather than failing the job

# Load and parse logs
raw_logs = sc.textFile("s3://bucket/logs/2024-*.log")
parsed_logs = raw_logs.map(parse_log_line).filter(lambda x: x is not None)
parsed_logs.cache()  # We'll use this multiple times

# Calculate daily revenue by product
purchases = parsed_logs.filter(lambda log: log["action"] == "purchase")
daily_revenue = (purchases
    .map(lambda log: ((log["timestamp"].date(), log["product_id"]), log["amount"]))
    .reduceByKey(lambda a, b: a + b)
    .map(lambda pair: f"{pair[0][0]},{pair[0][1]},{pair[1]:.2f}"))

daily_revenue.saveAsTextFile("output/daily_revenue")

# Find most active users
user_activity = (parsed_logs
    .map(lambda log: (log["user_id"], 1))
    .reduceByKey(lambda a, b: a + b)
    .sortBy(lambda pair: pair[1], ascending=False))

top_users = user_activity.take(100)

# Calculate conversion rate by hour
by_hour = parsed_logs.map(lambda log: (log["timestamp"].hour, log["action"]))
conversions = (by_hour
    .map(lambda pair: (pair[0], (1 if pair[1] == "purchase" else 0, 1)))
    .reduceByKey(lambda a, b: (a[0] + b[0], a[1] + b[1]))
    .mapValues(lambda v: v[0] / v[1] if v[1] > 0 else 0)
    .sortByKey())

print("Conversion rates by hour:")
for hour, rate in conversions.collect():
    print(f"{hour:02d}:00 - {rate:.2%}")

# Cleanup
parsed_logs.unpersist()
sc.stop()

This pipeline demonstrates key RDD concepts: parsing and validation, caching for reuse, key-value aggregations, and multiple output formats. The error handling in parse_log_line ensures malformed records don’t crash your job—critical for production systems.

Remember: RDDs give you power and flexibility, but that comes with responsibility. Profile your jobs, monitor partition skew, and always consider whether DataFrames might be a better fit for your use case.
