How to Perform Permutation Testing in Python
Key Insights
- Permutation testing makes no assumptions about your data’s distribution, making it ideal when normality assumptions fail or sample sizes are small
- The core logic is simple: if there’s no real effect, shuffling group labels shouldn’t change your results—do this thousands of times to build a null distribution
- Python’s `scipy.stats.permutation_test` handles most use cases, but understanding the manual implementation helps you customize tests for non-standard statistics
Introduction to Permutation Testing
Permutation testing is a resampling method that lets you test hypotheses without assuming your data follows a specific distribution. Instead of relying on theoretical distributions like the t-distribution or F-distribution, you create your own null distribution by shuffling your data.
Use permutation tests when:
- Your data violates normality assumptions
- Sample sizes are small (where asymptotic approximations break down)
- You’re working with non-standard test statistics
- You want results that are easy to explain to non-statisticians
The core concept is straightforward. You have two groups and want to know if they differ. Under the null hypothesis (no difference), the group labels are meaningless—any observation could belong to either group. So you shuffle the labels, calculate your test statistic, and repeat thousands of times. This gives you a distribution of what your statistic looks like when there’s truly no effect.
The Logic Behind Permutation Tests
Consider a clinical trial comparing drug versus placebo. You observe a mean difference of 5 points between groups. Is this real or just noise?
The permutation approach asks: “If the drug had no effect, how often would we see a difference this large or larger just by chance?”
To answer this, you:
- Pool all observations together (ignoring original group labels)
- Randomly assign observations to two groups matching your original group sizes
- Calculate the mean difference for this shuffled data
- Repeat steps 2-3 many times (typically 10,000+)
- Count how often the shuffled differences are as extreme as your observed difference
The p-value is simply the proportion of permuted statistics that are at least as extreme as your observed statistic. If only 2% of permutations produce a difference as large as what you observed, your p-value is 0.02.
This approach is conceptually clean: you’re directly simulating what randomness looks like under the null hypothesis, then checking if your data looks unusual compared to that.
Implementing a Permutation Test from Scratch
Let’s build a permutation test for comparing two group means. Understanding the manual implementation gives you flexibility to test any statistic you want.
```python
import numpy as np

def permutation_test_means(group1, group2, n_permutations=10000, random_state=42):
    """
    Perform a two-sample permutation test for difference in means.
    Returns the observed difference, null distribution, and two-tailed p-value.
    """
    rng = np.random.default_rng(random_state)

    # Observed test statistic
    observed_diff = np.mean(group1) - np.mean(group2)

    # Pool all observations
    pooled = np.concatenate([group1, group2])
    n1 = len(group1)

    # Build null distribution
    null_distribution = np.zeros(n_permutations)
    for i in range(n_permutations):
        # Shuffle and split
        shuffled = rng.permutation(pooled)
        perm_group1 = shuffled[:n1]
        perm_group2 = shuffled[n1:]
        # Calculate statistic for this permutation
        null_distribution[i] = np.mean(perm_group1) - np.mean(perm_group2)

    # Two-tailed p-value
    p_value = np.mean(np.abs(null_distribution) >= np.abs(observed_diff))
    return observed_diff, null_distribution, p_value

# Example usage
np.random.seed(42)
treatment = np.random.normal(105, 15, size=25)  # Treatment group
control = np.random.normal(100, 15, size=25)    # Control group

observed, null_dist, p = permutation_test_means(treatment, control)
print(f"Observed difference: {observed:.2f}")
print(f"P-value: {p:.4f}")
```
This implementation is explicit about what’s happening at each step. The key insight is that we’re treating the combined data as exchangeable under the null hypothesis—any observation could have come from either group.
For a one-tailed test, you’d modify the p-value calculation:
```python
# One-tailed (testing if group1 > group2)
p_value_one_tailed = np.mean(null_distribution >= observed_diff)
```
Using scipy.stats for Permutation Testing
For production code, use scipy.stats.permutation_test. It’s optimized, handles edge cases, and provides a clean API.
```python
from scipy import stats

def mean_difference(x, y, axis):
    """Statistic function for permutation_test."""
    return np.mean(x, axis=axis) - np.mean(y, axis=axis)

# Same data as before
np.random.seed(42)
treatment = np.random.normal(105, 15, size=25)
control = np.random.normal(100, 15, size=25)

result = stats.permutation_test(
    (treatment, control),
    statistic=mean_difference,
    n_resamples=10000,
    alternative='two-sided',
    permutation_type='independent'
)

print(f"Observed statistic: {result.statistic:.2f}")
print(f"P-value: {result.pvalue:.4f}")
```
The `permutation_type` parameter matters:
- `'independent'`: for independent samples (our two-group comparison)
- `'samples'`: for paired samples (before/after measurements)
- `'pairings'`: for testing association between paired observations
Here’s a paired sample example:
```python
# Paired data: same subjects measured before and after
before = np.array([85, 90, 78, 92, 88, 76, 95, 89, 84, 91])
after = np.array([88, 95, 82, 94, 91, 80, 98, 92, 87, 93])

def paired_mean_diff(x, y, axis):
    return np.mean(x - y, axis=axis)

result_paired = stats.permutation_test(
    (before, after),
    statistic=paired_mean_diff,
    n_resamples=10000,
    alternative='less',  # before - after is negative if scores improved
    permutation_type='samples'
)

print(f"Mean improvement: {np.mean(after - before):.2f}")
print(f"P-value: {result_paired.pvalue:.4f}")
```
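For paired data, `permutation_type='samples'` amounts to randomly swapping the before/after labels within each pair, which flips the sign of that pair's difference. Here is a minimal from-scratch sketch of that sign-flip idea on the same data (the function name `sign_flip_test` is ours, not scipy's):

```python
import numpy as np

def sign_flip_test(before, after, n_permutations=10000, random_state=42):
    """Paired permutation test: randomly swap the before/after labels
    within each pair, which flips the sign of that pair's difference."""
    rng = np.random.default_rng(random_state)
    diffs = after - before
    observed = np.mean(diffs)
    null_distribution = np.zeros(n_permutations)
    for i in range(n_permutations):
        # Flip the sign of a random subset of pair differences
        signs = rng.choice([-1, 1], size=len(diffs))
        null_distribution[i] = np.mean(signs * diffs)
    # Two-tailed p-value
    p_value = np.mean(np.abs(null_distribution) >= np.abs(observed))
    return observed, p_value

before = np.array([85, 90, 78, 92, 88, 76, 95, 89, 84, 91])
after = np.array([88, 95, 82, 94, 91, 80, 98, 92, 87, 93])
obs, p = sign_flip_test(before, after)
print(f"Mean improvement: {obs:.2f}, p = {p:.4f}")
```

Because every pair improved here, only the all-positive and all-negative sign patterns reach the observed mean difference, so the p-value comes out very small.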
Permutation Testing for Correlation
Permutation tests work beautifully for correlation. Under the null hypothesis of no association, the pairing between x and y values is arbitrary—shuffling one variable shouldn’t matter.
```python
def permutation_test_correlation(x, y, n_permutations=10000, random_state=42):
    """
    Permutation test for Pearson correlation significance.
    """
    rng = np.random.default_rng(random_state)

    # Observed correlation
    observed_r = np.corrcoef(x, y)[0, 1]

    # Build null distribution by shuffling one variable
    null_distribution = np.zeros(n_permutations)
    for i in range(n_permutations):
        y_shuffled = rng.permutation(y)
        null_distribution[i] = np.corrcoef(x, y_shuffled)[0, 1]

    # Two-tailed p-value
    p_value = np.mean(np.abs(null_distribution) >= np.abs(observed_r))
    return observed_r, null_distribution, p_value

# Example: testing correlation between study hours and test scores
np.random.seed(42)
study_hours = np.random.uniform(1, 10, size=30)
test_scores = 60 + 3 * study_hours + np.random.normal(0, 5, size=30)

r, null_dist, p = permutation_test_correlation(study_hours, test_scores)
print(f"Observed correlation: {r:.3f}")
print(f"P-value: {p:.4f}")
```
Using scipy for the same test:
```python
def correlation_statistic(x, y, axis):
    # Handle the non-vectorized case
    if axis is None:
        return np.corrcoef(x, y)[0, 1]
    # Vectorized Pearson correlation for scipy's batched computation
    x_centered = x - np.mean(x, axis=axis, keepdims=True)
    y_centered = y - np.mean(y, axis=axis, keepdims=True)
    numerator = np.sum(x_centered * y_centered, axis=axis)
    denominator = np.sqrt(np.sum(x_centered**2, axis=axis) *
                          np.sum(y_centered**2, axis=axis))
    return numerator / denominator

result = stats.permutation_test(
    (study_hours, test_scores),
    statistic=correlation_statistic,
    n_resamples=10000,
    permutation_type='pairings'
)

print(f"P-value (scipy): {result.pvalue:.4f}")
```
Visualizing Results
A histogram of the null distribution with your observed statistic marked is the clearest way to communicate permutation test results.
```python
import matplotlib.pyplot as plt

def plot_permutation_test(observed, null_distribution, p_value,
                          stat_name="Test Statistic"):
    """
    Create a publication-ready visualization of permutation test results.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Histogram of null distribution
    ax.hist(null_distribution, bins=50, density=True, alpha=0.7,
            color='steelblue', edgecolor='white', label='Null Distribution')

    # Observed statistic
    ax.axvline(observed, color='crimson', linewidth=2.5,
               linestyle='--', label=f'Observed = {observed:.3f}')

    # Mirror for two-tailed
    ax.axvline(-observed, color='crimson', linewidth=2.5,
               linestyle=':', alpha=0.5)

    # Shade rejection regions
    extreme_mask = np.abs(null_distribution) >= np.abs(observed)
    if np.any(extreme_mask):
        extreme_vals = null_distribution[extreme_mask]
        ax.hist(extreme_vals, bins=50, density=True, alpha=0.5,
                color='crimson', edgecolor='white')

    ax.set_xlabel(stat_name, fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    ax.set_title(f'Permutation Test Results (p = {p_value:.4f})', fontsize=14)
    ax.legend(loc='upper right', fontsize=10)
    plt.tight_layout()
    return fig, ax

# Generate and plot
observed, null_dist, p = permutation_test_means(treatment, control)
fig, ax = plot_permutation_test(observed, null_dist, p,
                                stat_name="Difference in Means")
plt.savefig('permutation_test_result.png', dpi=150)
plt.show()
```
Best Practices and Considerations
Number of permutations: For exploratory analysis, 1,000 permutations is often sufficient. For publication, use at least 10,000. Very small p-values demand more: the smallest nonzero p-value that n permutations can resolve is roughly 1/n, so estimating p ≈ 0.001 with any precision requires well over 10,000 permutations.
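The permutation count also controls the Monte Carlo error of the p-value itself: each permutation is roughly a Bernoulli(p) draw, so the estimate has standard error about sqrt(p(1 − p)/n). A quick sketch of how that error shrinks (the helper name is ours):

```python
import numpy as np

def p_value_standard_error(p, n_permutations):
    """Approximate Monte Carlo standard error of a permutation p-value,
    treating each permutation as an independent Bernoulli(p) draw."""
    return np.sqrt(p * (1 - p) / n_permutations)

# Error around an estimated p-value of 0.05 at several permutation counts
for n in (1_000, 10_000, 100_000):
    se = p_value_standard_error(0.05, n)
    print(f"n = {n:>7,}: p = 0.05 +/- {se:.4f}")
```

Quadrupling the precision costs sixteen times the permutations, which is why very small p-values are expensive to pin down.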
Exact vs. approximate tests: For small samples, you can enumerate all possible permutations. With n₁ = 5 and n₂ = 5, there are only 252 unique permutations. Use scipy.stats.permutation_test with n_resamples=np.inf for exact tests when computationally feasible.
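Full enumeration is also easy to do by hand with itertools.combinations: list every way to choose which pooled observations go into group 1. A minimal sketch for two groups of five (the data and the function name are invented for illustration):

```python
import numpy as np
from itertools import combinations

def exact_permutation_test(group1, group2):
    """Exact two-sample test: enumerate every split of the pooled data
    into groups of the original sizes (C(10, 5) = 252 for 5 vs. 5)."""
    pooled = np.concatenate([group1, group2])
    n1, n_total = len(group1), len(pooled)
    observed = np.mean(group1) - np.mean(group2)
    diffs = []
    for idx in combinations(range(n_total), n1):
        mask = np.zeros(n_total, dtype=bool)
        mask[list(idx)] = True
        diffs.append(pooled[mask].mean() - pooled[~mask].mean())
    diffs = np.asarray(diffs)
    # Exact two-tailed p-value over all possible splits
    p_value = np.mean(np.abs(diffs) >= np.abs(observed))
    return observed, len(diffs), p_value

g1 = np.array([12.1, 10.4, 11.8, 13.0, 12.5])
g2 = np.array([9.8, 10.1, 9.5, 10.9, 10.2])
obs, n_splits, p = exact_permutation_test(g1, g2)
print(f"{n_splits} splits, observed = {obs:.2f}, exact p = {p:.4f}")
```

Because the observed split is itself one of the enumerated splits, the exact p-value can never be smaller than 1/252 here.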
Computational efficiency: Permutation tests can be slow. Vectorize when possible:
```python
def fast_permutation_test(group1, group2, n_permutations=10000):
    """Vectorized permutation test - much faster for large n_permutations."""
    pooled = np.concatenate([group1, group2])
    n1, n_total = len(group1), len(pooled)
    observed = np.mean(group1) - np.mean(group2)

    # Generate all permutation indices at once
    rng = np.random.default_rng(42)
    indices = np.array([rng.permutation(n_total) for _ in range(n_permutations)])

    # Vectorized calculation of all permuted group means
    perm_group1_means = pooled[indices[:, :n1]].mean(axis=1)
    perm_group2_means = pooled[indices[:, n1:]].mean(axis=1)
    null_distribution = perm_group1_means - perm_group2_means

    p_value = np.mean(np.abs(null_distribution) >= np.abs(observed))
    return observed, null_distribution, p_value
```
Machine learning context: For model evaluation, use sklearn.inspection.permutation_importance or sklearn.model_selection.permutation_test_score:
```python
from sklearn.model_selection import permutation_test_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=100, n_features=20, random_state=42)
clf = RandomForestClassifier(n_estimators=100, random_state=42)

score, perm_scores, p_value = permutation_test_score(
    clf, X, y, cv=5, n_permutations=1000, random_state=42
)
print(f"True score: {score:.3f}, P-value: {p_value:.4f}")
```
When to use permutation tests: They shine when you have non-standard statistics, non-normal data, or small samples. They’re also excellent for teaching—the logic is intuitive and the results are easy to explain. However, they can be computationally expensive and may have less power than parametric tests when parametric assumptions are actually met.