How to Use Transfer Learning in PyTorch
Key Insights
- Transfer learning lets you leverage pre-trained models to achieve high accuracy with minimal data and training time—often reducing training from days to hours
- Feature extraction (freezing all pre-trained layers and training only a new classifier head) works best when your dataset is small and similar to the pre-training data, while fine-tuning (also updating some or all pre-trained layers) is better for larger datasets or different domains
- Using differential learning rates when fine-tuning—lower rates for early layers, higher for later ones—preserves valuable low-level features while adapting high-level representations to your specific task
Introduction to Transfer Learning
Transfer learning is the practice of taking a model trained on one task and adapting it to a related task. Instead of training a deep neural network from scratch—which requires massive datasets and computational resources—you start with weights learned from a large-scale dataset like ImageNet and fine-tune them for your specific problem.
This approach works because deep learning models learn hierarchical features. Early layers capture general patterns like edges and textures, while deeper layers learn task-specific features. Since those early-layer features transfer well across domains, you can reuse them and only retrain the final layers for your specific classification task.
Transfer learning shines when you have limited labeled data. Training a ResNet-50 from scratch might require millions of images, but with transfer learning, you can achieve excellent results with just thousands—or even hundreds—of examples. It’s the standard approach for computer vision tasks, medical imaging, satellite imagery analysis, and any domain where collecting massive labeled datasets is impractical.
Setting Up Your Environment
First, ensure you have PyTorch and torchvision installed. Torchvision provides pre-trained models and utilities for computer vision tasks.
pip install torch torchvision matplotlib
Let’s start by importing the necessary libraries and loading a pre-trained model:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import copy
# Check for GPU availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
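# Load a ResNet-50 with ImageNet-pretrained weights (torchvision >= 0.13;
# older releases used pretrained=True, which is now deprecated)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
print(f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")
The weights argument loads weights trained on ImageNet-1k, which has roughly 1.28 million training images across 1,000 classes. The weights are downloaded on first use. ResNet50_Weights.DEFAULT selects the best ImageNet weights torchvision provides for this architecture. This model already encodes a rich variety of visual features.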
Feature Extraction Approach
Feature extraction treats the pre-trained model as a fixed feature extractor. You freeze all convolutional layers and only train the final classifier layer. This is fast, requires less data, and works well when your task is similar to the original training task.
Here’s how to implement feature extraction with ResNet-50:
# Load pre-trained ResNet-50
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False
# Replace the final fully connected layer
# ResNet-50's final layer is called 'fc'; a newly created layer is trainable by default
num_features = model.fc.in_features
num_classes = 10  # Your number of classes
model.fc = nn.Linear(num_features, num_classes)
# Move model to device
model = model.to(device)
# Only the final layer parameters require gradients
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params}")
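When you freeze parameters with requires_grad = False, PyTorch does not compute gradients for them during backpropagation. Nothing in the backbone requires gradients, so autograd also does not need to keep the backbone's intermediate activations for the backward pass. This reduces memory usage and speeds up training. Only the new final layer is trained: 2048 → 10 here, or 20,490 parameters including the bias.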
Pass only the trainable parameters (here, the new fc layer) to the optimizer:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
Fine-Tuning Approach
Fine-tuning unfreezes some or all layers and trains them with a low learning rate. This adapts the pre-trained features to your specific domain, which is essential when your data differs significantly from ImageNet or when you have sufficient training data.
The key insight is using differential learning rates: lower rates for early layers (to preserve learned features) and higher rates for later layers (to adapt to your task).
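# Load pre-trained model
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
# Replace final layer
num_features = model.fc.in_features
num_classes = 10
model.fc = nn.Linear(num_features, num_classes)
model = model.to(device)
# Unfreeze all parameters (they're already unfrozen by default, but being explicit)
for param in model.parameters():
    param.requires_grad = True
# Give the early layers (stem, layer1, layer2) their own, lowest learning rate.
# A trainable parameter left out of the optimizer would receive gradients but never be updated.
early_params = [
    p for module in (model.conv1, model.bn1, model.layer1, model.layer2)
    for p in module.parameters()
]
# Create parameter groups with different learning rates (1e-5 for early layers is illustrative)
params_to_update = [
    {"params": early_params, "lr": 1e-5},
    {"params": model.layer3.parameters(), "lr": 5e-5},
    {"params": model.layer4.parameters(), "lr": 1e-4},
    {"params": model.fc.parameters(), "lr": 1e-3}
]
optimizer = optim.Adam(params_to_update)
criterion = nn.CrossEntropyLoss()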
This configuration gives the final classifier a learning rate 10x that of layer4 and 20x that of layer3. You can also freeze the earliest layers entirely, so their pre-trained weights stay fixed:
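# Freeze the stem and the first two residual stages before training starts.
# Optimizers such as Adam skip parameters whose .grad is None, so the optimizer can stay as is.
for module in (model.conv1, model.bn1, model.layer1, model.layer2):
    for param in module.parameters():
        param.requires_grad = False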
Training and Evaluation
Proper data preparation is critical. Use data augmentation to artificially expand your training set and improve generalization:
# Data transforms (Normalize uses the ImageNet per-channel means and standard deviations)
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}
# Load datasets (assuming ImageFolder structure: data_dir/train/<class>/..., data_dir/val/<class>/...)
data_dir = 'path/to/your/data'
image_datasets = {
    x: datasets.ImageFolder(f"{data_dir}/{x}", data_transforms[x])
    for x in ['train', 'val']
}
# Shuffle only the training data; validation order doesn't matter
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32, shuffle=(x == 'train'), num_workers=4)
    for x in ['train', 'val']
}
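Here's a complete training loop with validation. It uses the dataloaders, image_datasets, and device defined above, and it returns the weights with the best validation accuracy: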
def train_model(model, criterion, optimizer, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # enable dropout and batch-norm statistic updates
            else:
                model.eval()  # use running batch-norm statistics, disable dropout
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Track gradients only during the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # Accumulate per-sample sums so the epoch averages are exact
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
            epoch_loss = running_loss / len(image_datasets[phase])
            epoch_acc = running_corrects / len(image_datasets[phase])
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Keep a copy of the weights with the best validation accuracy so far
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_model_wts)
    return model, history
# Train the model
model, history = train_model(model, criterion, optimizer, num_epochs=20)
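After training, it is useful to measure accuracy on a held-out test split that played no part in choosing the best epoch. The following is a minimal sketch of a hypothetical evaluate helper. It switches the model to eval mode and disables gradient tracking. The demo uses a tiny stand-in model whose predictions are known in advance:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Hypothetical helper: top-1 accuracy of `model` over every batch in `loader`."""
    model.eval()  # fixed batch-norm statistics, no dropout
    correct, total = 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total


# Stand-in model that always predicts class 0: zero weights, bias favors class 0
stub = nn.Linear(4, 2)
with torch.no_grad():
    stub.weight.zero_()
    stub.bias.copy_(torch.tensor([1.0, 0.0]))

# Half the labels are 0 and half are 1, so accuracy should be 0.5
data = TensorDataset(torch.randn(10, 4), torch.tensor([0] * 5 + [1] * 5))
accuracy = evaluate(stub, DataLoader(data, batch_size=4))
print(f"accuracy: {accuracy:.2f}")
```

With the tutorial's objects, you would call it as evaluate(model, test_loader, device), where test_loader is a hypothetical DataLoader over your test split.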
Practical Tips and Best Practices
Choose the right approach: As a rough rule of thumb, use feature extraction when you have fewer than about 1,000 images per class or when your domain is similar to ImageNet. Use fine-tuning when you have more data or when your domain differs significantly (medical images, satellite imagery, etc.).
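Learning rate selection: Start with 0.001 for new layers and 0.0001 for fine-tuning pre-trained layers. Use a learning rate scheduler to reduce rates as training progresses, and call scheduler.step() once per epoch after the optimizer updates. For example, StepLR multiplies every parameter group's learning rate by gamma (0.1) every step_size (7) epochs: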
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
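StepLR rescales every parameter group, so the ratios between differential learning rates are preserved as they decay. The following sketch uses two stand-in parameter groups that mirror the rates above: a head at 1e-3 and a backbone at 1e-4. It records the learning rates at the start of each epoch. A single dummy step stands in for a full epoch of batches:

```python
import torch
import torch.nn as nn
import torch.optim as optim

head = nn.Linear(8, 2)      # stand-in for the new classifier
backbone = nn.Linear(8, 8)  # stand-in for pre-trained layers

optimizer = optim.Adam([
    {"params": head.parameters(), "lr": 1e-3},
    {"params": backbone.parameters(), "lr": 1e-4},
])
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lr_history = []
for epoch in range(20):
    lr_history.append([group["lr"] for group in optimizer.param_groups])
    # Placeholder for one epoch of training: a single dummy step
    optimizer.zero_grad()
    loss = head(torch.randn(4, 8)).sum() + backbone(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per epoch, after the optimizer has stepped

for epoch in (0, 7, 14):
    print(epoch, lr_history[epoch])
```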
Gradual unfreezing: Instead of unfreezing all layers at once, try progressive unfreezing. Train the classifier first, then unfreeze the last ResNet block (layer4) and train more, then unfreeze earlier blocks. This helps reduce catastrophic forgetting, where large gradient updates driven by a freshly initialized classifier overwrite useful pre-trained features.
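Batch normalization layers: When fine-tuning, decide whether to freeze batch normalization layers. With small batches, per-batch statistics are noisy, so it is generally better to keep BN layers in eval mode, where they use their stored running statistics. model.train() puts every submodule back into training mode, so re-apply this after each model.train() call:
model.train()
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()  # use stored running mean/var and stop updating them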
Monitor for overfitting: Watch the validation loss. If it increases while training loss decreases, you’re overfitting. Increase data augmentation, add dropout, or reduce model capacity.
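You can check for this pattern programmatically using the history dictionary returned by train_model. The following is a sketch of a hypothetical helper that flags the epochs where validation loss rose while training loss fell. It is demonstrated on illustrative values, not real training results:

```python
def overfitting_epochs(history):
    """Hypothetical helper: 1-based epochs where val loss rose while train loss fell vs. the previous epoch."""
    flagged = []
    train_loss, val_loss = history["train_loss"], history["val_loss"]
    for i in range(1, len(val_loss)):
        if val_loss[i] > val_loss[i - 1] and train_loss[i] < train_loss[i - 1]:
            flagged.append(i + 1)
    return flagged


# Illustrative values only (not real training results)
history = {
    "train_loss": [1.20, 0.80, 0.60, 0.45, 0.35],
    "val_loss":   [1.10, 0.85, 0.80, 0.86, 0.95],
}
print(overfitting_epochs(history))  # → [4, 5]
```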
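Save checkpoints: Always save your best model based on validation accuracy, not training accuracy. The train_model function above already restores the best validation weights before returning, so saving afterward stores that checkpoint: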
torch.save(model.state_dict(), 'best_model.pth')
Conclusion
Transfer learning is one of the most practical tools for applied deep learning. By leveraging pre-trained models, you can build production-ready image classifiers with limited data and computational resources. Start with feature extraction for quick experiments, then progress to fine-tuning as you gather more data and refine your approach.
The choice between feature extraction and fine-tuning depends on your dataset size and domain similarity to ImageNet. When in doubt, start with feature extraction—it’s faster and less prone to overfitting. As you collect more data, transition to fine-tuning with differential learning rates to squeeze out additional performance.
Remember that transfer learning isn’t magic. It works best when the source and target domains share visual similarities. For highly specialized domains with no ImageNet overlap, you might need domain-specific pre-trained models or more extensive fine-tuning. But for the vast majority of computer vision tasks, starting with ImageNet weights will give you a significant head start.