How to Fine-Tune Pretrained Models in PyTorch
Key Insights
- Fine-tuning pretrained models can cut training time by roughly 10-100x compared to training from scratch, and it often achieves better performance with less data
- The key to successful fine-tuning is using lower learning rates (typically 10-100x smaller) for pretrained layers than for newly initialized layers to preserve learned features
- Always match your preprocessing pipeline exactly to what the pretrained model expects; mismatched normalization statistics are a common cause of poor transfer learning results
Introduction to Transfer Learning and Fine-Tuning
Transfer learning is the practice of taking a model trained on one task and adapting it to a related task. Fine-tuning specifically refers to continuing the training process on your custom dataset while starting from pretrained weights rather than random initialization.
The advantages are substantial. Training a ResNet50 on ImageNet from scratch requires days on multiple GPUs and roughly 1.2 million labeled images. Fine-tuning that same architecture for your custom task might take hours on a single GPU with just thousands of images. The pretrained model has already learned fundamental features: edges, textures, and shapes for vision models, or syntactic and semantic patterns for language models.
Common use cases include adapting ImageNet-trained vision models for medical imaging, satellite imagery analysis, or product classification, and using BERT or GPT models for domain-specific text classification, named entity recognition, or question answering.
Loading a Pretrained Model
PyTorch makes loading pretrained models straightforward through torchvision.models, Hugging Face’s transformers library, or torch.hub. Let’s focus on computer vision with torchvision.
import torch
import torch.nn as nn
from torchvision import models
# Load pretrained ResNet50
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# Inspect the architecture
print(model)
# Examine the final layer
print(f"Original classifier: {model.fc}")
print(f"Input features to FC layer: {model.fc.in_features}")
print(f"Output classes: {model.fc.out_features}")
The model architecture shows ResNet50’s structure: an initial conv layer, followed by four residual layer groups (layer1-4), adaptive average pooling, and a final fully connected layer. The FC layer has 2048 input features and 1000 output classes (ImageNet categories).
Understanding this structure is critical. The early layers learn generic features (edges, colors), middle layers learn patterns (textures, parts), and late layers learn task-specific features. This hierarchy informs our freezing strategy.
Preparing Your Custom Dataset
Your custom dataset must match the preprocessing expectations of the pretrained model. For ImageNet-trained models, this means specific normalization statistics.
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
class CustomImageDataset(Dataset):
    def __init__(self, image_dir, labels, transform=None):
        """
        Args:
            image_dir: Directory with all images
            labels: List of (filename, label_idx) tuples
            transform: Optional transform to apply
        """
        self.image_dir = image_dir
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_name, label = self.labels[idx]
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
# Critical: Use ImageNet normalization stats
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# Create datasets and dataloaders
# Each list holds (filename, label_idx) tuples; replace the placeholders with your data
train_labels = [("img1.jpg", 0), ("img2.jpg", 1), ...]
val_labels = [...]  # Same format as train_labels, for the validation split
train_dataset = CustomImageDataset("./data/train", train_labels, train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataset = CustomImageDataset("./data/val", val_labels, val_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
The normalization values [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225] are ImageNet statistics. Using different values will cause a distribution shift that degrades performance.
Freezing and Unfreezing Layers
The standard approach is to freeze the pretrained layers initially and train only the new classification head. The head starts from random weights, so its early gradients are large; if they flow into the backbone, they can overwrite the pretrained features (catastrophic forgetting).
# Replace the final layer for your number of classes
num_classes = 10 # Your task
model.fc = nn.Linear(model.fc.in_features, num_classes)
# Freeze all layers except the final classifier
for param in model.parameters():
    param.requires_grad = False
# Unfreeze only the final layer
for param in model.fc.parameters():
    param.requires_grad = True
# Verify what's trainable
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable_params:,} / {total_params:,} parameters")
After training the head for a few epochs, you can optionally unfreeze deeper layers for fine-tuning:
# Unfreeze the last residual block (layer4)
for param in model.layer4.parameters():
    param.requires_grad = True
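One detail worth checking when freezing layers: requires_grad = False stops gradient updates, but BatchNorm layers in train mode still update their running mean and variance on every forward pass. The sketch below shows this on a toy model (not ResNet50) and one way to keep the statistics fixed. The helper name freeze_batchnorm_stats is hypothetical, and whether to freeze the statistics at all is a judgment call that depends on how far your data is from ImageNet.

```python
import torch
import torch.nn as nn


def freeze_batchnorm_stats(module: nn.Module) -> None:
    """Put every BatchNorm layer inside `module` into eval mode (hypothetical helper)."""
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()


torch.manual_seed(0)

# Toy stand-in for a frozen backbone.
backbone = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
for p in backbone.parameters():
    p.requires_grad = False

bn = backbone[1]
x = torch.randn(8, 3, 16, 16)

# Train mode: running statistics move even though all parameters are frozen.
backbone.train()
before = bn.running_mean.clone()
backbone(x)
changed_in_train_mode = not torch.equal(before, bn.running_mean)

# Train mode with BatchNorm layers switched to eval: statistics stay fixed.
backbone.train()
freeze_batchnorm_stats(backbone)
before = bn.running_mean.clone()
backbone(x)
changed_when_frozen = not torch.equal(before, bn.running_mean)

print(changed_in_train_mode, changed_when_frozen)  # True False
```

Calling model.train() puts BatchNorm layers back into train mode, so a helper like this would need to run after each model.train() call.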
Training Loop and Optimization
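Use different learning rates for pretrained and new layers. The pretrained layers need small updates to preserve learned features, while the new classifier needs larger updates to learn the task. The optimizer below includes layer4, so it assumes you have unfrozen that block as shown earlier; parameters that are still frozen receive no gradients and are left unchanged.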
import torch.optim as optim
# Discriminative learning rates
optimizer = optim.Adam([
{'params': model.layer4.parameters(), 'lr': 1e-4}, # Pretrained layers
{'params': model.fc.parameters(), 'lr': 1e-3} # New classifier head
])
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
def train_epoch(model, train_loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    val_loss = running_loss / len(val_loader)
    val_acc = 100. * correct / total
    return val_loss, val_acc
# Training loop
num_epochs = 10
best_val_acc = 0.0
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
        }, 'best_model.pth')
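The checkpoint saved above can be restored later for evaluation or to resume training. Below is a minimal sketch of the round trip, using a tiny stand-in model and a temporary file so that it runs on its own. In this article's setting you would instead rebuild the ResNet50 with the same replaced head before calling load_state_dict. The epoch and accuracy values are made up for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Tiny stand-in model; the real code would recreate ResNet50 with the new head.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Save a checkpoint with the same keys as the training loop uses.
path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
torch.save({
    "epoch": 3,        # made-up value for illustration
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "val_acc": 91.5,   # made-up value for illustration
}, path)

# Rebuild the same architecture and optimizer, then load the saved state.
restored = nn.Linear(4, 2)
restored_optimizer = optim.Adam(restored.parameters(), lr=1e-3)

checkpoint = torch.load(path, map_location="cpu")
restored.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # resume from the next epoch

print(f"Resuming at epoch {start_epoch}, best val acc {checkpoint['val_acc']:.2f}%")
```

Passing map_location="cpu" lets a checkpoint written on a GPU machine load on a CPU-only one; move the model to the target device afterwards.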
Evaluation and Common Pitfalls
Beyond accuracy, evaluate with metrics appropriate for your task. For imbalanced datasets, use F1-score, precision, and recall.
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
def evaluate_model(model, test_loader, device, class_names):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
    # Classification report
    print(classification_report(all_labels, all_preds, target_names=class_names))
    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    print("Confusion Matrix:")
    print(cm)
    return np.array(all_preds), np.array(all_labels)
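To see why accuracy alone can mislead on imbalanced data, the sketch below derives per-class precision, recall, and F1 from a confusion matrix laid out like scikit-learn's (rows are true classes, columns are predictions). The counts are invented for illustration, and the function name per_class_metrics is hypothetical.

```python
def per_class_metrics(cm):
    """Return a (precision, recall, f1) tuple per class from a square confusion matrix.

    cm[i][j] counts samples whose true class is i and predicted class is j.
    """
    n = len(cm)
    results = []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        results.append((precision, recall, f1))
    return results


# Invented counts: 95 samples of class 0, only 5 of class 1.
cm = [[94, 1],
      [4, 1]]

accuracy = sum(cm[k][k] for k in range(len(cm))) / sum(map(sum, cm))
metrics = per_class_metrics(cm)

print(f"accuracy: {accuracy:.2f}")
for k, (p, r, f1) in enumerate(metrics):
    print(f"class {k}: precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
```

Here accuracy is 0.95 while recall on the minority class is only 0.20, which is the kind of gap the classification report makes visible.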
Common pitfalls to avoid:
- Wrong preprocessing: Not using the same normalization as pretraining causes immediate performance degradation
- Learning rate too high: Values above 1e-3 for pretrained layers often destroy learned features
- Not freezing initially: Training all layers from the start can lead to catastrophic forgetting
- Batch size too small: Very small batches (< 8) can cause unstable batch normalization statistics
Advanced Techniques
For even better results, implement progressive unfreezing with learning rate scheduling:
from torch.optim.lr_scheduler import CosineAnnealingLR
# Start with only the head trainable
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Unfreeze layer4 after 3 epochs and layer3 after 6 (epochs are 0-indexed)
def progressive_unfreeze(model, optimizer, epoch):
    if epoch == 3:
        for param in model.layer4.parameters():
            param.requires_grad = True
        # Register the newly trainable params with the optimizer
        optimizer.add_param_group({'params': model.layer4.parameters(), 'lr': 1e-4})
    elif epoch == 6:
        for param in model.layer3.parameters():
            param.requires_grad = True
        optimizer.add_param_group({'params': model.layer3.parameters(), 'lr': 5e-5})

# Note: the scheduler recorded its base learning rates when it was created,
# so check how your PyTorch version schedules param groups added afterwards.
# In training loop
for epoch in range(num_epochs):
    progressive_unfreeze(model, optimizer, epoch)
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    scheduler.step()
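For intuition about what CosineAnnealingLR does to the head's learning rate, the closed-form schedule from the PyTorch documentation can be computed directly: eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2. The sketch below evaluates it for lr=1e-3 and T_max=10 as used above, assuming the default eta_min of 0; the function name cosine_lr is hypothetical.

```python
import math


def cosine_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing value after `step` scheduler steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


# Learning rate of the head at the start of each epoch, plus the final value.
schedule = [cosine_lr(t, base_lr=1e-3, t_max=10) for t in range(11)]
for t, lr in enumerate(schedule):
    print(f"epoch {t:2d}: lr = {lr:.6f}")
```

The rate starts at the base value, passes through half of it at the midpoint, and reaches eta_min at T_max, so the layers unfrozen later in training see a smaller share of the schedule.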
Fine-tuning pretrained models is the most practical approach for most deep learning tasks. Start conservative with frozen layers and low learning rates, then gradually increase model flexibility as needed. With proper technique, you’ll achieve strong results with a fraction of the data and compute required for training from scratch.