How to Plot the Precision-Recall Curve in Python
Key Insights
- Precision-Recall curves are superior to ROC curves for imbalanced datasets because they focus exclusively on the positive class performance without being inflated by true negatives
- The optimal threshold on a PR curve depends entirely on your business context—whether false positives or false negatives are more costly determines where you should operate
- A no-skill classifier baseline on a PR curve is a horizontal line at y = (positive samples / total samples), not the diagonal line you see in ROC curves
Introduction to Precision-Recall Curves
Precision-Recall (PR) curves visualize the trade-off between precision and recall across different classification thresholds. Unlike ROC curves that plot true positive rate against false positive rate, PR curves show how precision changes as recall increases.
Use PR curves when working with imbalanced datasets where the positive class is rare—fraud detection, disease diagnosis, anomaly detection, or spam classification. ROC curves can be misleadingly optimistic in these scenarios because they factor in true negatives, which dominate imbalanced datasets. A classifier that predicts “negative” for everything achieves 99% accuracy and a 100% true negative rate on a 99:1 imbalanced dataset, so metrics that reward true negatives look good despite the model being useless.
PR curves don’t suffer from this problem. They focus exclusively on how well your model identifies the minority class, making them the right choice when that’s what actually matters.
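To make this concrete, here is a small sketch (the 99:1 labels below are made up for illustration, not the dataset used later) showing how an all-negative classifier looks impressive on accuracy and true negative rate while catching nothing:

```python
import numpy as np

# Hypothetical 99:1 imbalanced labels: 990 negatives, 10 positives
y_true = np.array([0] * 990 + [1] * 10)

# A useless classifier that predicts "negative" for everything
y_pred = np.zeros_like(y_true)

accuracy = np.mean(y_true == y_pred)                     # 0.99 -- looks great
true_negative_rate = np.mean(y_pred[y_true == 0] == 0)   # 1.00 -- perfect
recall = np.sum((y_true == 1) & (y_pred == 1)) / np.sum(y_true == 1)  # 0.0

print(f"Accuracy: {accuracy:.2f}, TNR: {true_negative_rate:.2f}, Recall: {recall:.2f}")
```

Accuracy and true negative rate both look excellent, yet recall is zero: the model never finds a single positive.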
Understanding Precision and Recall Metrics
Precision measures what proportion of positive predictions are actually correct:
Precision = TP / (TP + FP)
Recall (also called sensitivity or true positive rate) measures what proportion of actual positives you successfully identified:
Recall = TP / (TP + FN)
Both metrics derive from the confusion matrix, which breaks down predictions into true positives (TP), true negatives (TN), false positives (FP), and false negatives (FN).
Here’s how to calculate these manually:
```python
import numpy as np

def calculate_precision_recall(y_true, y_pred):
    """Calculate precision and recall from predictions."""
    # Create confusion matrix components
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    return precision, recall

# Example with sample data
y_true = np.array([1, 0, 1, 1, 0, 1, 0, 0, 1, 0])
y_pred = np.array([1, 0, 1, 0, 0, 1, 1, 0, 1, 0])

precision, recall = calculate_precision_recall(y_true, y_pred)
print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
```
The fundamental trade-off: increasing recall (catching more positives) typically decreases precision (more false positives sneak in), and vice versa.
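You can watch this trade-off happen by sweeping a cutoff over a set of scores (the labels and probability scores below are made up for illustration):

```python
import numpy as np

# Toy labels and predicted scores to illustrate the trade-off (synthetic)
y_true = np.array([1, 0, 1, 1, 0, 1, 0, 0, 1, 0])
y_scores = np.array([0.9, 0.4, 0.8, 0.35, 0.2, 0.7, 0.6, 0.1, 0.75, 0.3])

for threshold in [0.3, 0.5, 0.7]:
    # Everything at or above the threshold is predicted positive
    y_pred = (y_scores >= threshold).astype(int)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    print(f"threshold={threshold}: precision={precision:.2f}, recall={recall:.2f}")
```

As the threshold rises from 0.3 to 0.7, precision climbs while recall falls: fewer false positives sneak in, but more true positives are missed.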
Preparing Sample Data and Training a Model
Let’s create an imbalanced dataset simulating credit card fraud detection, where fraudulent transactions represent only 2% of all transactions:
```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
import numpy as np

# Generate imbalanced dataset
X, y = make_classification(
    n_samples=10000,
    n_features=20,
    n_informative=15,
    n_redundant=5,
    n_classes=2,
    weights=[0.98, 0.02],  # 2% positive class
    flip_y=0.01,
    random_state=42
)
print(f"Positive class ratio: {np.sum(y) / len(y):.3f}")

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# Train a logistic regression model
lr_model = LogisticRegression(random_state=42, max_iter=1000)
lr_model.fit(X_train, y_train)

# Get prediction probabilities (not hard predictions)
y_proba = lr_model.predict_proba(X_test)[:, 1]
print(f"Probability range: [{y_proba.min():.3f}, {y_proba.max():.3f}]")
```
Notice we call predict_proba() rather than predict(). The probabilities let us explore different classification thresholds, which is the entire point of a PR curve.
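As a quick illustration of why the probabilities matter, here is a sketch (using hypothetical probability values, not the model's actual output) of turning scores into hard predictions at different cutoffs:

```python
import numpy as np

# Hypothetical predicted probabilities (stand-ins for predict_proba output)
y_proba = np.array([0.05, 0.92, 0.40, 0.73, 0.15])

# predict() is locked to a fixed 0.5 cutoff; with probabilities we can pick any threshold
default_preds = (y_proba >= 0.5).astype(int)  # same as predict()
strict_preds = (y_proba >= 0.8).astype(int)   # higher bar for a positive call

print(default_preds)  # [0 1 0 1 0]
print(strict_preds)   # [0 1 0 0 0]
```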
Computing Precision-Recall Values
Scikit-learn’s precision_recall_curve() function does the heavy lifting, computing precision and recall at every unique probability threshold in your predictions:
```python
from sklearn.metrics import precision_recall_curve, average_precision_score

# Compute precision-recall pairs at different thresholds
precision, recall, thresholds = precision_recall_curve(y_test, y_proba)

# Calculate average precision (area under PR curve)
ap_score = average_precision_score(y_test, y_proba)

print(f"Number of thresholds: {len(thresholds)}")
print(f"Average Precision: {ap_score:.3f}")
print("\nFirst 5 thresholds:")
for i in range(5):
    print(f"Threshold: {thresholds[i]:.3f}, Precision: {precision[i]:.3f}, Recall: {recall[i]:.3f}")
```
The function returns arrays where each element represents precision and recall at a specific threshold. Note that precision and recall have one more element than thresholds because they include the endpoint where recall=0.
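You can verify this length relationship on a tiny example (the labels and scores below are illustrative):

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Tiny illustrative example (synthetic labels and scores)
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

# precision and recall carry one extra endpoint: precision=1, recall=0
print(len(precision), len(recall), len(thresholds))  # 4 4 3
assert len(precision) == len(thresholds) + 1
assert precision[-1] == 1.0 and recall[-1] == 0.0
```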
Plotting the Precision-Recall Curve
Now for the visualization. Start with a basic plot, then enhance it with meaningful annotations:
```python
import matplotlib.pyplot as plt

# Basic PR curve
plt.figure(figsize=(10, 6))
plt.plot(recall, precision, linewidth=2, label=f'Logistic Regression (AP={ap_score:.3f})')

# Add no-skill baseline
no_skill = np.sum(y_test) / len(y_test)
plt.plot([0, 1], [no_skill, no_skill], linestyle='--', label='No Skill Baseline', color='gray')

plt.xlabel('Recall', fontsize=12)
plt.ylabel('Precision', fontsize=12)
plt.title('Precision-Recall Curve', fontsize=14)
plt.legend(loc='best')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
For a more informative version with threshold annotations:
```python
fig, ax = plt.subplots(figsize=(12, 7))

# Plot PR curve
ax.plot(recall, precision, linewidth=2, label=f'PR Curve (AP={ap_score:.3f})')

# Annotate specific thresholds
threshold_points = [0.3, 0.5, 0.7]
for thresh in threshold_points:
    # Find closest threshold
    idx = np.argmin(np.abs(thresholds - thresh))
    ax.scatter(recall[idx], precision[idx], s=100, zorder=5)
    ax.annotate(f'θ={thresh}',
                xy=(recall[idx], precision[idx]),
                xytext=(10, -10), textcoords='offset points',
                fontsize=10,
                bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

# No-skill baseline
ax.plot([0, 1], [no_skill, no_skill], linestyle='--', label=f'No Skill ({no_skill:.3f})', color='gray')

ax.set_xlabel('Recall', fontsize=12)
ax.set_ylabel('Precision', fontsize=12)
ax.set_title('Precision-Recall Curve with Threshold Annotations', fontsize=14)
ax.legend(loc='best')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
Comparing Multiple Models
Comparing different models on the same PR curve reveals which performs better at identifying the positive class:
```python
# Train multiple models
models = {
    'Logistic Regression': LogisticRegression(random_state=42, max_iter=1000),
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42, max_depth=10),
}

plt.figure(figsize=(12, 7))
for name, model in models.items():
    # Train and predict
    model.fit(X_train, y_train)
    y_proba = model.predict_proba(X_test)[:, 1]

    # Compute PR curve
    precision, recall, _ = precision_recall_curve(y_test, y_proba)
    ap = average_precision_score(y_test, y_proba)

    # Plot
    plt.plot(recall, precision, linewidth=2, label=f'{name} (AP={ap:.3f})')

# Add baseline
plt.plot([0, 1], [no_skill, no_skill], linestyle='--', label=f'No Skill ({no_skill:.3f})', color='gray')

plt.xlabel('Recall', fontsize=12)
plt.ylabel('Precision', fontsize=12)
plt.title('Model Comparison: Precision-Recall Curves', fontsize=14)
plt.legend(loc='best')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
The model with the curve furthest toward the top-right corner (higher precision at every recall level) performs best.
Best Practices and Interpretation Tips
Reading the curve: The top-right corner represents perfect classification (100% precision and recall). The closer your curve hugs this corner, the better. The area under the curve (Average Precision) summarizes overall performance—higher is better.
Choosing thresholds: Don’t blindly use 0.5. Your threshold should reflect business costs. If false positives are expensive (like flagging legitimate transactions as fraud), operate at higher precision (lower recall). If false negatives are costly (missing actual fraud), prioritize recall.
When to use PR curves: Always use them for imbalanced datasets. If your positive class is under 10-20% of your data, ROC curves will mislead you.
Here’s a function to find the optimal threshold by maximizing the F-beta score (F1 by default):
```python
def find_optimal_threshold(y_true, y_proba, metric='f1', beta=1.0):
    """
    Find the optimal threshold based on the specified metric.

    Parameters:
    - metric: currently only 'f1' (the F-beta family) is supported
    - beta: weight for the F-beta score (beta > 1 favors recall,
      beta < 1 favors precision)
    """
    if metric != 'f1':
        raise ValueError(f"Unsupported metric: {metric}")

    precision, recall, thresholds = precision_recall_curve(y_true, y_proba)

    # precision/recall have one extra endpoint; drop it to align with thresholds
    fbeta_scores = ((1 + beta**2) * precision[:-1] * recall[:-1]) / \
                   (beta**2 * precision[:-1] + recall[:-1] + 1e-10)
    optimal_idx = np.argmax(fbeta_scores)
    optimal_threshold = thresholds[optimal_idx]

    print(f"Optimal threshold: {optimal_threshold:.3f}")
    print(f"F{beta:g}-score: {fbeta_scores[optimal_idx]:.3f}")
    print(f"Precision: {precision[optimal_idx]:.3f}")
    print(f"Recall: {recall[optimal_idx]:.3f}")
    return optimal_threshold

# Example: Find threshold maximizing F1-score
optimal_thresh = find_optimal_threshold(y_test, y_proba, metric='f1')

# Example: Favor recall (F2-score)
optimal_thresh_recall = find_optimal_threshold(y_test, y_proba, metric='f1', beta=2.0)
```
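When your error costs are explicit rather than captured by an F-score, another option is to pick the threshold that minimizes total cost directly. The sketch below uses made-up per-error costs (fp_cost, fn_cost) and a hypothetical helper name; substitute your own costs:

```python
import numpy as np

def find_cost_optimal_threshold(y_true, y_proba, fp_cost=1.0, fn_cost=10.0):
    """Pick the threshold minimizing total misclassification cost.

    fp_cost and fn_cost are hypothetical per-error costs; set them
    from your own business context.
    """
    best_threshold, best_cost = 0.5, np.inf
    for t in np.unique(y_proba):
        y_pred = (y_proba >= t).astype(int)
        fp = np.sum((y_true == 0) & (y_pred == 1))
        fn = np.sum((y_true == 1) & (y_pred == 0))
        cost = fp * fp_cost + fn * fn_cost
        if cost < best_cost:
            best_threshold, best_cost = t, cost
    return best_threshold, best_cost

# Toy example: missing a positive costs 10x a false alarm
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_proba = np.array([0.9, 0.3, 0.6, 0.4, 0.2, 0.55, 0.8, 0.1])
t, c = find_cost_optimal_threshold(y_true, y_proba)
print(f"Cost-optimal threshold: {t:.2f}, total cost: {c:.0f}")
```

Because false negatives are ten times as expensive here, the chosen threshold sits low enough to catch every positive while tolerating a false alarm.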
The PR curve is your diagnostic tool for imbalanced classification. Use it to understand model behavior, compare alternatives, and make informed threshold decisions based on your specific business context. Don’t settle for default thresholds—your data deserves better.