How to Use Train-Test-Validation Split in Python
Key Insights
- Always split your data before any preprocessing to prevent data leakage—fit scalers and encoders only on training data, then transform validation and test sets
- Use stratified splitting for classification tasks to maintain class distribution across all sets, especially critical when dealing with imbalanced datasets
- For time series data, never use random splitting—maintain temporal order by splitting chronologically to avoid unrealistic future information leaking into your training process
Introduction to Data Splitting
Data splitting is the foundation of honest machine learning model evaluation. Without proper splitting, you’re essentially grading your own homework with the answer key in hand—your model’s performance metrics become meaningless.
The three-way split serves distinct purposes. The training set is where your model learns patterns from data. The validation set helps you tune hyperparameters and make architectural decisions without contaminating your final evaluation. The test set remains completely untouched until the end, providing an unbiased estimate of how your model will perform on new, unseen data.
Common split ratios include 60-20-20 or 70-15-15 (train-validation-test). For larger datasets, you might use 80-10-10, since even 10% of a large dataset yields enough samples for reliable validation and testing. The key principle: your training set should be large enough to learn meaningful patterns, while validation and test sets must be large enough to give stable, trustworthy performance estimates.
# Conceptual flow of data through the three sets
# Raw Data (100%)
# ↓
# Training (60-70%) → Model learns patterns
# ↓
# Validation (15-20%) → Tune hyperparameters, select features
# ↓
# Test (15-20%) → Final, unbiased evaluation (touch once!)
Basic Train-Test Split with Scikit-learn
Scikit-learn’s train_test_split() is your starting point. For a simple two-way split, it handles shuffling and splitting in one function call.
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
import pandas as pd
# Load sample data
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = pd.Series(iris.target, name='species')
# Basic 80-20 split
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,      # 20% for testing
    random_state=42,    # Reproducibility
    stratify=y          # Maintain class distribution
)
print(f"Training set: {len(X_train)} samples")
print(f"Test set: {len(X_test)} samples")
print(f"\nClass distribution in training:\n{y_train.value_counts(normalize=True)}")
print(f"\nClass distribution in test:\n{y_test.value_counts(normalize=True)}")
The random_state parameter is crucial for reproducibility—use the same value to get identical splits across runs. The stratify parameter ensures each class appears in the same proportion in both sets, critical for imbalanced classification problems.
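A quick sketch of what random_state buys you, using toy data rather than the iris split above: calling train_test_split twice with the same seed reproduces the split exactly.

```python
from sklearn.model_selection import train_test_split
import numpy as np

# Toy data: 10 samples, 2 features, balanced labels
X = np.arange(20).reshape(10, 2)
y = np.array([0, 1] * 5)

# Same random_state → identical splits on every run
X_tr_a, X_te_a, y_tr_a, y_te_a = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
X_tr_b, X_te_b, y_tr_b, y_te_b = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print((X_te_a == X_te_b).all())  # True: the two splits are identical
```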
Creating the Three-Way Split
For train-validation-test splits, you have two practical approaches.
Method 1: Sequential splitting (recommended for clarity):
from sklearn.model_selection import train_test_split
# First split: separate test set (20%)
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    stratify=y
)
# Second split: divide remaining into train (75% of 80% = 60%) and validation (25% of 80% = 20%)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp,
    test_size=0.25,  # 25% of 80% = 20% of total
    random_state=42,
    stratify=y_temp
)
print(f"Training: {len(X_train)} ({len(X_train)/len(X)*100:.1f}%)")
print(f"Validation: {len(X_val)} ({len(X_val)/len(X)*100:.1f}%)")
print(f"Test: {len(X_test)} ({len(X_test)/len(X)*100:.1f}%)")
Method 2: Manual numpy splitting (more explicit control):
import numpy as np
# Set seed for reproducibility
np.random.seed(42)
# Shuffle indices
indices = np.random.permutation(len(X))
# Calculate split points
train_end = int(0.6 * len(X))
val_end = int(0.8 * len(X))
# Split indices
train_idx = indices[:train_end]
val_idx = indices[train_end:val_end]
test_idx = indices[val_end:]
# Create splits
X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
X_test, y_test = X.iloc[test_idx], y.iloc[test_idx]
I recommend Method 1 for most cases—it’s cleaner and leverages scikit-learn’s stratification capabilities.
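If you split data often, Method 1 can be wrapped in a small helper. The train_val_test_split function below is a hypothetical convenience, not part of scikit-learn; the key detail is rescaling the validation fraction relative to the data remaining after the test split.

```python
from sklearn.model_selection import train_test_split

def train_val_test_split(X, y, val_size=0.2, test_size=0.2,
                         random_state=42, stratify=True):
    """Three-way split via two sequential train_test_split calls."""
    strat = y if stratify else None
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=strat
    )
    # Rescale: val_size is a fraction of the *whole* dataset,
    # but the second split only sees (1 - test_size) of it
    rel_val = val_size / (1.0 - test_size)
    strat_temp = y_temp if stratify else None
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=rel_val,
        random_state=random_state, stratify=strat_temp
    )
    return X_train, X_val, X_test, y_train, y_val, y_test

# Example: 150 iris samples → 90 train / 30 val / 30 test
from sklearn.datasets import load_iris
iris = load_iris()
X_train, X_val, X_test, y_train, y_val, y_test = train_val_test_split(
    iris.data, iris.target
)
print(len(X_train), len(X_val), len(X_test))  # 90 30 30
```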
Time Series Data Splitting
Random splitting destroys temporal relationships and creates data leakage in time series problems. If you train on future data to predict the past, your model’s performance will be artificially inflated and worthless in production.
import pandas as pd
import numpy as np
# Create sample time series data
dates = pd.date_range('2020-01-01', periods=1000, freq='D')
df = pd.DataFrame({
    'date': dates,
    'value': np.cumsum(np.random.randn(1000)),
    'feature': np.random.randn(1000)
})
# Sort by date (crucial!)
df = df.sort_values('date').reset_index(drop=True)
# Sequential split - no shuffling
train_size = int(0.6 * len(df))
val_size = int(0.2 * len(df))
train_df = df[:train_size]
val_df = df[train_size:train_size + val_size]
test_df = df[train_size + val_size:]
print(f"Training period: {train_df['date'].min()} to {train_df['date'].max()}")
print(f"Validation period: {val_df['date'].min()} to {val_df['date'].max()}")
print(f"Test period: {test_df['date'].min()} to {test_df['date'].max()}")
# Verify no temporal overlap
assert train_df['date'].max() < val_df['date'].min()
assert val_df['date'].max() < test_df['date'].min()
For time series, consider using a rolling window approach or time-based cross-validation (TimeSeriesSplit in scikit-learn) for more robust evaluation.
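A minimal sketch of TimeSeriesSplit on dummy data: each fold's training window ends strictly before its validation window begins, and the training window grows from fold to fold.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.arange(100).reshape(-1, 1)  # 100 sequential observations

tscv = TimeSeriesSplit(n_splits=5)
for fold, (train_idx, val_idx) in enumerate(tscv.split(X)):
    # Training indices always precede validation indices
    print(f"Fold {fold}: train [0..{train_idx[-1]}], "
          f"val [{val_idx[0]}..{val_idx[-1]}]")
    assert train_idx.max() < val_idx.min()
```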
Cross-Validation Alternative
Cross-validation eliminates the need for a fixed validation set by creating multiple train-validation splits. This approach maximizes data utilization and provides more reliable performance estimates.
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.ensemble import RandomForestClassifier
# Prepare data (still hold out test set!)
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
# Create 5-fold cross-validation
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# Train and evaluate model
model = RandomForestClassifier(random_state=42)
cv_scores = cross_val_score(model, X_temp, y_temp, cv=skf, scoring='accuracy')
print(f"Cross-validation scores: {cv_scores}")
print(f"Mean CV accuracy: {cv_scores.mean():.3f} (+/- {cv_scores.std() * 2:.3f})")
# Train final model on all non-test data
model.fit(X_temp, y_temp)
test_score = model.score(X_test, y_test)
print(f"Final test accuracy: {test_score:.3f}")
Use cross-validation when you have limited data or need robust hyperparameter tuning. Use a fixed validation set when you have abundant data or need faster iteration during development.
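In practice, cross-validated hyperparameter tuning is usually driven through GridSearchCV, which runs the fold loop for you. A sketch on the iris data, with a deliberately tiny, made-up parameter grid to keep it fast:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split

iris = load_iris()
X_temp, X_test, y_temp, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42,
    stratify=iris.target
)

# Toy grid for illustration only
param_grid = {"n_estimators": [25, 50], "max_depth": [None, 3]}

search = GridSearchCV(
    RandomForestClassifier(random_state=42),
    param_grid,
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=42),
    scoring="accuracy",
)
search.fit(X_temp, y_temp)  # CV happens only on the non-test data

print("Best params:", search.best_params_)
print(f"Held-out test accuracy: {search.score(X_test, y_test):.3f}")
```

The test set plays no part in the search; it is scored exactly once, after the grid search has picked its winner.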
Best Practices and Common Pitfalls
Always split before preprocessing. Fitting preprocessors on the full dataset before splitting is the most common mistake that leads to data leakage:
from sklearn.preprocessing import StandardScaler
# ❌ WRONG: Scaling before split causes leakage
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X) # Test data statistics leak into training!
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2)
# ✅ CORRECT: Split first, then fit on training only
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train) # Learn from training only
X_test_scaled = scaler.transform(X_test) # Apply training statistics
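A Pipeline removes the temptation to get this order wrong: the scaler is fit only on whatever data the pipeline is trained on, including the training portion of each cross-validation fold. A sketch using iris and logistic regression:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42,
    stratify=iris.target
)

# The scaler is re-fit on the training part of every CV fold,
# so validation statistics never leak into training
pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
scores = cross_val_score(pipe, X_train, y_train, cv=5)
print(f"CV accuracy: {scores.mean():.3f}")

pipe.fit(X_train, y_train)
print(f"Test accuracy: {pipe.score(X_test, y_test):.3f}")
```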
Handle imbalanced data carefully. Always use stratified splitting for classification:
# For severely imbalanced data
from sklearn.model_selection import train_test_split
# Stratify ensures minority classes appear in all sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    stratify=y  # Critical for imbalanced datasets
)
# Verify class distribution
print("Training class distribution:")
print(y_train.value_counts(normalize=True))
print("\nTest class distribution:")
print(y_test.value_counts(normalize=True))
Document your splits. Save indices or use consistent random states:
import joblib
# Save split indices for reproducibility
split_indices = {
    'train': train_idx.tolist(),
    'val': val_idx.tolist(),
    'test': test_idx.tolist()
}
joblib.dump(split_indices, 'data_split_indices.pkl')
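A hypothetical round trip showing that saved indices reconstruct exactly the same split on reload:

```python
import os
import tempfile

import joblib
import numpy as np

# Generate a split (stand-in for the indices produced above)
rng = np.random.default_rng(42)
indices = rng.permutation(100)
split_indices = {
    'train': indices[:60].tolist(),
    'val': indices[60:80].tolist(),
    'test': indices[80:].tolist(),
}

path = os.path.join(tempfile.gettempdir(), 'data_split_indices.pkl')
joblib.dump(split_indices, path)

# Later (or on another machine): reload and rebuild the identical split
loaded = joblib.load(path)
assert loaded == split_indices  # same partition, run after run
```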
Never touch test data during development. Treat it as truly unseen data. If you find yourself repeatedly evaluating on the test set and adjusting your model, you’ve turned it into a validation set and compromised your evaluation.
The discipline of proper data splitting separates amateur from professional machine learning practice. Master these techniques, and you’ll build models whose performance metrics you can actually trust.