PySpark - Sample DataFrame (Random Rows)
Key Insights
- PySpark’s sample() method uses fractions rather than exact counts, so you must calculate a fraction when you need a specific number of rows; use orderBy(rand()) with limit() for precise control at the cost of performance.
- Stratified sampling via sampleBy() lets you specify a different sampling rate per category, which is essential for controlling class distribution in imbalanced datasets used for machine learning.
- Always set a seed for reproducible sampling in production pipelines, and consider reducing the partition count of small samples so Spark doesn't schedule many undersized partitions.
Introduction
Sampling DataFrames is a fundamental operation in PySpark that you’ll use constantly, whether you’re testing transformations on a subset of production data, exploring unfamiliar datasets, or creating train/test splits for machine learning. Unlike pandas, where you typically work with an entire dataset in memory, PySpark’s distributed nature makes sampling both more critical and more nuanced: full datasets are often too large to inspect directly, and samples are computed independently on each partition.
PySpark provides three primary methods for sampling: sample() for basic random sampling, sampleBy() for stratified sampling, and randomSplit() for dividing DataFrames into multiple subsets. Each serves different purposes, and choosing the right approach impacts both correctness and performance.
Basic Sampling with sample()
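The sample() method is your go-to for basic random sampling. It takes three parameters: withReplacement (boolean), fraction (float), and seed (optional integer for reproducibility). Without replacement, fraction is the probability that each row is kept and must lie between 0 and 1; with replacement, it is the expected number of times each row is selected and can exceed 1.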
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("sampling").getOrCreate()
# Create sample DataFrame
data = [(i, f"user_{i}", i * 10) for i in range(1, 10001)]
df = spark.createDataFrame(data, ["id", "username", "score"])
# Sample 10% of rows without replacement
sample_df = df.sample(withReplacement=False, fraction=0.1, seed=42)
print(f"Original count: {df.count()}")
print(f"Sample count: {sample_df.count()}")
The fraction parameter is probabilistic, not deterministic. A fraction of 0.1 doesn’t guarantee exactly 10% of rows—it means each row has a 10% chance of inclusion. With 10,000 rows, you’ll get approximately 1,000 rows, but it could be 980 or 1,020.
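To see how far the count can drift, you can treat sample() without replacement as an independent coin flip per row, which makes the sample size binomially distributed. The plain-Python model below does not call Spark; `bernoulli_sample_size` is a hypothetical helper used only for illustration. It computes the expected size and standard deviation for 10,000 rows at fraction 0.1 and simulates a few sample sizes:

```python
import math
import random


def bernoulli_sample_size(n_rows: int, fraction: float, seed: int) -> int:
    """Keep each row independently with probability `fraction`; return how many survive."""
    rng = random.Random(seed)
    return sum(1 for _ in range(n_rows) if rng.random() < fraction)


n_rows, fraction = 10_000, 0.1

# Sample size ~ Binomial(n_rows, fraction)
expected = n_rows * fraction
std_dev = math.sqrt(n_rows * fraction * (1 - fraction))
print(f"expected ~{expected:.0f} rows, standard deviation ~{std_dev:.0f}")

# A few simulated sample sizes, one per seed
sizes = [bernoulli_sample_size(n_rows, fraction, seed) for seed in range(5)]
print(sizes)
```

With a standard deviation of 30 rows, counts anywhere from roughly 940 to 1,060 (two standard deviations either side) are unremarkable.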
Sampling with replacement is useful for bootstrapping techniques:
# Sample with replacement - rows can appear multiple times
bootstrap_sample = df.sample(withReplacement=True, fraction=0.5, seed=123)
# Some rows will appear multiple times
bootstrap_sample.groupBy("id").count().filter("count > 1").show(5)
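With replacement, each row can be picked zero, one, or several times. Spark's with-replacement sampler draws a Poisson-distributed copy count for each row, so fraction acts as the expected number of copies per row. The plain-Python sketch below models that behavior (it is not Spark code; `poisson_draw` is a hypothetical helper using Knuth's method) and compares the share of duplicated rows with the theoretical value 1 - exp(-fraction) * (1 + fraction):

```python
import math
import random


def poisson_draw(rng: random.Random, lam: float) -> int:
    """Draw one Poisson(lam) value with Knuth's method (fine for small lam)."""
    limit = math.exp(-lam)
    k, product = 0, rng.random()
    while product > limit:
        k += 1
        product *= rng.random()
    return k


rng = random.Random(123)
fraction = 0.5
# Number of copies of each of 10,000 rows in a with-replacement sample
copies = [poisson_draw(rng, fraction) for _ in range(10_000)]

total_rows = sum(copies)                          # expected about 10,000 * 0.5
duplicated = sum(1 for c in copies if c > 1)      # rows appearing more than once
p_dup = 1 - math.exp(-fraction) * (1 + fraction)  # theoretical P(copies >= 2)
print(total_rows, duplicated, round(10_000 * p_dup))
```

Under this model, roughly 900 of the 10,000 ids should show up in the `count > 1` output of the bootstrap example.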
Always pass a seed in production code. Without one, you get a different sample on every run, which makes debugging and validation much harder:
# Reproducible sampling
sample1 = df.sample(False, 0.1, seed=42)
sample2 = df.sample(False, 0.1, seed=42)
# Identical, as long as df's data and partitioning are unchanged
# Non-reproducible sampling
sample3 = df.sample(False, 0.1)
sample4 = df.sample(False, 0.1)
# These will (almost certainly) differ
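A seed makes a sample reproducible only for the same input layout. Spark derives each partition's random stream from the seed and the partition index, so the same seed applied after a repartition, or to data that arrives in a different order, can select different rows. The plain-Python model below mimics that per-partition seeding; `seeded_sample` and `split_into` are hypothetical helpers, and this is a simplification rather than Spark's code:

```python
import random


def seeded_sample(partitions, fraction, seed):
    """Bernoulli-sample each partition with an RNG seeded from seed + partition index."""
    kept = []
    for index, rows in enumerate(partitions):
        rng = random.Random(seed + index)
        kept.extend(row for row in rows if rng.random() < fraction)
    return kept


def split_into(rows, num_partitions):
    """Round-robin rows into partitions (a stand-in for a repartition)."""
    return [rows[i::num_partitions] for i in range(num_partitions)]


rows = list(range(1, 1001))
run_a = seeded_sample(split_into(rows, 4), 0.1, seed=42)
run_b = seeded_sample(split_into(rows, 4), 0.1, seed=42)  # same layout, same seed
run_c = seeded_sample(split_into(rows, 8), 0.1, seed=42)  # same seed, different layout
print(run_a == run_b, sorted(run_a) == sorted(run_c))
```

If a sample must stay exactly the same across runs regardless of upstream changes, writing it to storage once is the safer option.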
Sampling Exact Number of Rows
The probabilistic nature of sample() is problematic when you need exactly N rows. You need to calculate the fraction based on your dataset size:
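def sample_n_rows(df, n, seed=None):
    """Sample approximately n rows by converting n into a fraction."""
    total_count = df.count()  # requires a full pass over the data
    fraction = min(1.0, n / total_count)  # fraction may not exceed 1.0 without replacement
    return df.sample(False, fraction, seed=seed)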
# Get approximately 500 rows
sample_500 = sample_n_rows(df, 500, seed=42)
print(f"Requested: 500, Got: {sample_500.count()}")
This still won’t give you exactly 500 rows. For precise control, use limit() with orderBy(rand()):
from pyspark.sql.functions import rand
# Get exactly 500 random rows
exact_sample = df.orderBy(rand(seed=42)).limit(500)
print(f"Exact count: {exact_sample.count()}")  # 500, as long as df has at least 500 rows
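However, this approach is much more expensive. orderBy(rand()) generates a random key for every row and sorts on it, so every row must be processed and ranked before limit() can return (Spark may plan orderBy followed by limit as a per-partition top-N plus a final merge rather than a full global sort, but it still touches all of the data). sample(), by contrast, filters rows in a single pass without sorting. Use sample() for large datasets where approximate counts are acceptable, and reserve orderBy(rand()).limit() for smaller datasets or cases where an exact count is critical: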
# Performance comparison on a large dataset
import time
from pyspark.sql.functions import rand
large_df = spark.range(0, 10000000)  # 10 million rows
# Fast but approximate (count() forces evaluation of the lazy sample)
start = time.time()
sample_fast = large_df.sample(False, 0.01, seed=42)
sample_fast.count()
print(f"sample() time: {time.time() - start:.2f}s")
# Slow but exact
start = time.time()
sample_exact = large_df.orderBy(rand(42)).limit(100000)
sample_exact.count()
print(f"orderBy().limit() time: {time.time() - start:.2f}s")
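A common middle ground is to oversample slightly with sample() and then trim to exactly n rows, so the sort only touches a small sample. The sketch below picks the fraction so that the expected sample size m satisfies m - z*sqrt(m) = n, using sqrt(m) as an upper bound on the standard deviation of the binomial sample size. `oversample_fraction` and `sample_exact_n` are hypothetical helpers, and the z = 3.0 default is an assumption: with it the sample falls short of n only rarely, and in that case you simply get fewer rows.

```python
import math


def oversample_fraction(n: int, total_count: int, z: float = 3.0) -> float:
    """Fraction whose expected sample size m satisfies m - z*sqrt(m) = n, capped at 1.0."""
    # sqrt(m) is the positive root of s^2 - z*s - n = 0
    s = (z + math.sqrt(z * z + 4 * n)) / 2
    return min(1.0, s * s / total_count)


def sample_exact_n(df, n: int, seed: int = 42, z: float = 3.0):
    """Oversample with sample(), then shuffle and trim to n rows (fewer if the sample falls short)."""
    from pyspark.sql.functions import rand  # imported lazily so the math above runs without Spark

    fraction = oversample_fraction(n, df.count(), z)
    return df.sample(False, fraction, seed=seed).orderBy(rand(seed)).limit(n)


# For 500 of the 10,000 rows in the earlier example
fraction = oversample_fraction(500, 10_000)
print(f"fraction={fraction:.4f}, expected rows={fraction * 10_000:.0f}")
```

The orderBy(rand(seed)) step before limit() matters: limit() alone keeps whichever rows arrive first, which biases the result toward the earliest partitions.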
Stratified Sampling
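Stratified sampling controls how many rows you draw from each group in your data. This matters for imbalanced datasets, where a small random sample can contain very few rows from minority classes.
The sampleBy() method takes a column name, a dictionary mapping category values to sampling fractions, and an optional seed. Each row is kept independently with its category’s fraction, so per-category counts are approximate, and categories missing from the dictionary are treated as having a fraction of 0. With the same fraction for every category, the result is statistically equivalent to sample() with that fraction; the benefit comes from setting different fractions per category: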
# Create imbalanced dataset
from pyspark.sql.functions import col, when
df_imbalanced = spark.range(0, 10000).withColumn(
"category",
when((col("id") % 10) < 7, "common")
.when((col("id") % 10) < 9, "rare")
.otherwise("very_rare")
)
# Check distribution
df_imbalanced.groupBy("category").count().show()
# Stratified sampling - 10% from each category
fractions = {
"common": 0.1,
"rare": 0.1,
"very_rare": 0.1
}
stratified_sample = df_imbalanced.sampleBy("category", fractions, seed=42)
stratified_sample.groupBy("category").count().show()
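Conceptually, sampleBy() keeps each row with the probability given for its category and drops categories that are missing from the dictionary. The plain-Python model below reproduces the 70/20/10 layout of `df_imbalanced` and shows both behaviors; `category` and `sample_by` are hypothetical helpers, not Spark's implementation:

```python
import random
from collections import Counter


def category(i: int) -> str:
    """Same bucketing as df_imbalanced: 70% common, 20% rare, 10% very_rare."""
    r = i % 10
    return "common" if r < 7 else "rare" if r < 9 else "very_rare"


def sample_by(rows, fractions, seed):
    """Keep each (id, category) row with its category's probability; missing categories get 0.0."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fractions.get(row[1], 0.0)]


rows = [(i, category(i)) for i in range(10_000)]
print(Counter(c for _, c in rows))  # population per category

equal = sample_by(rows, {"common": 0.1, "rare": 0.1, "very_rare": 0.1}, seed=42)
print(Counter(c for _, c in equal))  # expected about 700 / 200 / 100

partial = sample_by(rows, {"common": 0.1, "rare": 0.1}, seed=42)  # "very_rare" omitted
print(Counter(c for _, c in partial))  # no very_rare rows at all
```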
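You can also set different fractions per category to rebalance the sample: undersample the majority class and keep a larger share (or all) of the minority classes. sampleBy() samples without replacement, so every fraction must lie between 0 and 1; it cannot duplicate minority rows the way true oversampling would: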
from pyspark.sql.functions import col
# Rebalance: keep relatively more of the rare categories
fractions_balanced = {
    "common": 0.05,    # Undersample the majority (~350 of 7,000)
    "rare": 0.3,       # Keep a larger share (~600 of 2,000)
    "very_rare": 1.0   # Take all rows (1,000)
}
balanced_sample = df_imbalanced.sampleBy("category", fractions_balanced, seed=42)
balanced_sample.groupBy("category").count().show()
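Rather than hand-picking fractions, you can derive them from the class counts so that each class lands near a target size. `fractions_for_target` below is a hypothetical helper; fractions are capped at 1.0 because sampleBy() cannot take more rows than a class has:

```python
def fractions_for_target(class_counts: dict, target_per_class: int) -> dict:
    """Per-class sampleBy() fractions aiming for about target_per_class rows each (max 1.0)."""
    return {
        label: min(1.0, target_per_class / count)
        for label, count in class_counts.items()
        if count > 0
    }


# Counts for df_imbalanced, e.g. from:
#   {row["category"]: row["count"]
#    for row in df_imbalanced.groupBy("category").count().collect()}
class_counts = {"common": 7000, "rare": 2000, "very_rare": 1000}

fractions_target = fractions_for_target(class_counts, 1000)
expected = {label: class_counts[label] * f for label, f in fractions_target.items()}
print(fractions_target)
print(expected)
```

You would then pass the result to `df_imbalanced.sampleBy("category", fractions_target, seed=42)`; each class should come out near 1,000 rows, subject to the usual sampling noise.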
This technique is invaluable for machine learning when you need balanced training data but want to preserve the original distribution in your test set.
Splitting DataFrames with randomSplit()
The randomSplit() method divides a DataFrame into multiple subsets in one operation. It takes a list of weights and an optional seed:
# 70/30 train-test split
train_df, test_df = df.randomSplit([0.7, 0.3], seed=42)
print(f"Training rows: {train_df.count()}")
print(f"Test rows: {test_df.count()}")
The weights don’t need to sum to 1.0—they’re normalized automatically:
# These are equivalent
split1 = df.randomSplit([0.7, 0.3], seed=42)
split2 = df.randomSplit([7, 3], seed=42)
split3 = df.randomSplit([70, 30], seed=42)
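Normalization works because, conceptually, randomSplit() gives each row one uniform random draw and assigns the row to the split whose cumulative-weight interval contains that draw, which also makes the splits disjoint and exhaustive. The plain-Python model below ignores partitioning and is not Spark's code; `random_split` is a hypothetical helper. It shows that [0.7, 0.3] and [7, 3] produce the same assignment:

```python
import random
from itertools import accumulate


def random_split(rows, weights, seed):
    """Assign each row to one split using cumulative normalized weights and one draw per row."""
    total = sum(weights)
    bounds = list(accumulate(w / total for w in weights))  # e.g. [0.7, 1.0]
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        # First interval whose upper bound exceeds u; the last split absorbs rounding error
        index = next((i for i, ub in enumerate(bounds) if u < ub), len(bounds) - 1)
        splits[index].append(row)
    return splits


rows = list(range(10_000))
by_fraction = random_split(rows, [0.7, 0.3], seed=42)
by_ratio = random_split(rows, [7, 3], seed=42)
print([len(s) for s in by_fraction], [len(s) for s in by_ratio])
```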
For machine learning pipelines, you typically need three splits:
# 60% train, 20% validation, 20% test
train_df, val_df, test_df = df.randomSplit([0.6, 0.2, 0.2], seed=42)
print(f"Train: {train_df.count()}")
print(f"Validation: {val_df.count()}")
print(f"Test: {test_df.count()}")
# Verify total
total = train_df.count() + val_df.count() + test_df.count()
print(f"Total: {total}, Original: {df.count()}")
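Like sample(), the split sizes are probabilistic: you won’t get exactly 60/20/20, only close approximations. The splits are disjoint and together contain every row, though, so the total matches the original count.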
Performance Considerations & Best Practices
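Sampling does not change a DataFrame’s partition count. A 1% sample of a 1TB dataset spread across 1,000 partitions still has 1,000 partitions, but they now hold only about 10GB in total (roughly 10MB each), so every downstream stage still schedules 1,000 small tasks. Reducing the partition count fixes this: coalesce() merges partitions without a full shuffle, while repartition() shuffles the data but balances it evenly: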
# Check partition distribution
sample_small = df.sample(False, 0.01, seed=42)
print(f"Partitions: {sample_small.rdd.getNumPartitions()}")
# Repartition for efficiency
sample_optimized = sample_small.repartition(10)
print(f"Optimized partitions: {sample_optimized.rdd.getNumPartitions()}")
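To choose a partition count instead of hard-coding one, you can size partitions from the estimated sample size. The sketch below assumes you know (or can estimate) the input size in bytes and targets about 128 MiB per partition, a common rule of thumb that matches the default of spark.sql.files.maxPartitionBytes. `partitions_for_sample` and `shrink_partitions` are hypothetical helpers, and the example uses binary units (1 TiB) for the 1TB figure above:

```python
import math

TARGET_PARTITION_BYTES = 128 * 1024**2  # ~128 MiB per partition (rule of thumb)


def partitions_for_sample(input_bytes: int, fraction: float,
                          target_bytes: int = TARGET_PARTITION_BYTES) -> int:
    """Partition count that keeps an estimated input_bytes * fraction near target_bytes each."""
    return max(1, math.ceil(input_bytes * fraction / target_bytes))


def shrink_partitions(sample_df, input_bytes: int, fraction: float):
    """Reduce the partition count of a sampled DataFrame without a full shuffle."""
    return sample_df.coalesce(partitions_for_sample(input_bytes, fraction))


# The 1% sample of a 1 TiB dataset from the paragraph above
print(partitions_for_sample(1024**4, 0.01))
```

Because coalesce() can only reduce the partition count, a computed value larger than the current count leaves the DataFrame unchanged.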
Cache samples you’ll reuse multiple times, especially after expensive operations:
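# Sample once, use many times (cache() is lazy: the first action materializes it)
sample_cached = df.sample(False, 0.1, seed=42).cache()
# Multiple actions on the same cached sample
result1 = sample_cached.filter(col("score") > 50).count()
result2 = sample_cached.groupBy("username").count().collect()
result3 = sample_cached.agg({"score": "avg"}).first()[0]
# Release the cached data when you're done
sample_cached.unpersist()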
For production ML pipelines, document your sampling strategy clearly:
# Production-ready sampling function
def create_ml_splits(df, train_frac=0.7, val_frac=0.15, test_frac=0.15, seed=42):
    """
    Split DataFrame into train/val/test sets.
    Args:
        df: Input DataFrame
        train_frac: Training set fraction (default 0.7)
        val_frac: Validation set fraction (default 0.15)
        test_frac: Test set fraction (default 0.15)
        seed: Random seed for reproducibility
    Returns:
        Tuple of (train_df, val_df, test_df)
    """
    # Validate explicitly: assert statements are skipped when Python runs with -O
    if abs((train_frac + val_frac + test_frac) - 1.0) >= 0.001:
        raise ValueError("Fractions must sum to 1.0")
    train, val, test = df.randomSplit(
        [train_frac, val_frac, test_frac],
        seed=seed
    )
    # Shrink the partition count of the smaller splits; coalesce() avoids a full shuffle
    num_partitions = df.rdd.getNumPartitions()
    if val_frac < 0.2:
        val = val.coalesce(max(1, int(num_partitions * val_frac)))
    if test_frac < 0.2:
        test = test.coalesce(max(1, int(num_partitions * test_frac)))
    return train, val, test
train, val, test = create_ml_splits(df, seed=42)
Sampling is deceptively simple but critical to get right. Use sample() for quick exploration and approximate subsets, sampleBy() when per-category sampling rates matter, and randomSplit() for ML workflows. Always set seeds, shrink the partition count of small samples, and remember that PySpark sampling is probabilistic, so plan accordingly when exact counts matter.