How to Use Optimizers in PyTorch
Key Insights
- Adam and AdamW are the default choices for most deep learning tasks, but SGD with momentum often achieves better generalization on computer vision problems when properly tuned
- The learning rate is the most critical hyperparameter—start with 1e-3 for Adam, 1e-1 for SGD, and always use learning rate scheduling for production models
- Forgetting optimizer.zero_grad() before backpropagation is the most common bug that causes incorrect gradient accumulation and failed training
Introduction to Optimizers in Deep Learning
Optimizers are the engines that drive neural network training. They implement algorithms that adjust model parameters to minimize the loss function through variants of gradient descent. In PyTorch, the torch.optim module provides implementations of all major optimization algorithms, from basic Stochastic Gradient Descent (SGD) to sophisticated adaptive methods like Adam.
The optimizer you choose directly affects training speed, final model performance, and convergence stability. A poorly chosen optimizer or misconfigured hyperparameters can mean the difference between a model that trains in hours and one that takes days, or between a model that converges and one that diverges entirely.
Basic Optimizer Setup
Every PyTorch optimizer follows the same pattern: instantiate it with your model’s parameters, then use it in your training loop to update weights after computing gradients.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define a simple neural network (784 inputs, e.g. flattened 28x28 images; 10 classes)
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)  # raw logits; CrossEntropyLoss applies log-softmax itself

# Placeholder random data so the loop runs; substitute your real DataLoader
train_loader = DataLoader(
    TensorDataset(torch.randn(512, 784), torch.randint(0, 10, (512,))),
    batch_size=64,
    shuffle=True,
)
num_epochs = 5

# Instantiate model, optimizer, and loss
model = SimpleNet()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# Basic training loop structure
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Zero gradients from previous iteration
        optimizer.zero_grad()
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        # Backward pass: compute gradients
        loss.backward()
        # Update weights
        optimizer.step()
```
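The critical sequence is: zero the gradients, run the forward pass and compute the loss, backpropagate, then update the parameters. The same pattern works for every optimizer in torch.optim except LBFGS, whose step() must be passed a closure that re-evaluates the model and returns the loss.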
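To make optimizer.step() less of a black box, here is a minimal sketch that checks one step of plain SGD (no momentum, no weight decay) against the textbook rule p ← p − lr · grad. The toy linear model, the random batch, and lr=0.1 are illustrative assumptions, not part of the tutorial's setup.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 1)                        # toy model for illustration
optimizer = optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)   # random toy batch

# The standard sequence: zero grads, forward pass and loss, backward pass
optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Predict the update before stepping: p_new = p - lr * p.grad
expected = [p.detach() - 0.1 * p.grad for p in model.parameters()]

optimizer.step()

# Compare the library's update with the hand-computed one, parameter by parameter
for p, e in zip(model.parameters(), expected):
    print(torch.allclose(p.detach(), e))
```

Every other optimizer in this article only changes how the step is computed from the gradient; the surrounding zero_grad / backward / step sequence stays the same.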
Common Optimizer Types and When to Use Them
PyTorch provides several optimizers, each with different characteristics:
SGD uses the basic gradient descent update rule, optionally with momentum. It’s simple, memory-efficient, and often generalizes better than adaptive methods on vision tasks, but requires careful learning rate tuning.
Adam adapts learning rates per parameter using estimates of first and second moments of gradients. It’s robust to hyperparameter choices and works well across diverse problems, making it the default choice for most practitioners.
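AdamW is Adam with decoupled weight decay: instead of adding an L2 term to the gradient, where Adam's per-parameter scaling would distort it, it shrinks the weights directly at each step. Use it instead of Adam whenever you want weight decay, especially with transformers.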
RMSprop adapts learning rates using a moving average of squared gradients. It predates Adam and is used less often today, though it still appears in some recurrent-network and reinforcement-learning setups.
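Adagrad adapts learning rates using the sum of all past squared gradients. Because that sum only grows, the effective learning rate keeps shrinking and can decay too aggressively over a long training run, so Adagrad is rarely used in deep learning.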
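The contrast between Adagrad's ever-growing sum and RMSprop's moving average is easiest to see with a constant gradient. In this sketch the loss is simply L(w) = w, so the gradient is 1 on every step, and we record how far each optimizer moves w; the loss, lr=0.1, and the 1000-step horizon are illustrative choices.

```python
import torch
import torch.optim as optim

def step_sizes(make_optimizer, steps=1000):
    """Return |change in w| at each step when the gradient is always 1."""
    w = torch.zeros(1, dtype=torch.float64, requires_grad=True)
    optimizer = make_optimizer([w])
    sizes = []
    for _ in range(steps):
        optimizer.zero_grad()
        w.sum().backward()          # loss L(w) = w, so dL/dw = 1 every step
        before = w.item()
        optimizer.step()
        sizes.append(abs(w.item() - before))
    return sizes

adagrad = step_sizes(lambda params: optim.Adagrad(params, lr=0.1))
rmsprop = step_sizes(lambda params: optim.RMSprop(params, lr=0.1, alpha=0.99))

for name, sizes in [("Adagrad", adagrad), ("RMSprop", rmsprop)]:
    print(f"{name}: step 1 = {sizes[0]:.4f}, step 10 = {sizes[9]:.4f}, "
          f"step 1000 = {sizes[-1]:.4f}")
```

Adagrad's step falls roughly as lr/√t (about 0.003 by step 1000), while RMSprop's settles near lr = 0.1 once its moving average has warmed up. RMSprop's early steps are larger because its average starts at zero and PyTorch's RMSprop applies no bias correction.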
```python
# Initialize different optimizers with the same model
# (these are alternatives; a real training run uses one of them)
model = SimpleNet()

# SGD with momentum - good for computer vision
optimizer_sgd = optim.SGD(
    model.parameters(),
    lr=0.1,  # Higher LR than Adam
    momentum=0.9,
    weight_decay=1e-4,
)

# Adam - general purpose default
optimizer_adam = optim.Adam(
    model.parameters(),
    lr=1e-3,  # Standard starting point
    betas=(0.9, 0.999),
    eps=1e-8,
)

# AdamW - better weight decay, great for transformers
optimizer_adamw = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.01,  # Decoupled weight decay
)

# RMSprop - often used for RNNs
optimizer_rmsprop = optim.RMSprop(
    model.parameters(),
    lr=1e-3,
    alpha=0.99,
    momentum=0.0,
)
```
For new projects, start with AdamW. Switch to SGD with momentum if you’re working on computer vision and have time to tune hyperparameters.
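As a sketch of what momentum=0.9 does inside torch.optim.SGD (with the default dampening=0 and nesterov=False), the update is buf ← momentum · buf + grad followed by p ← p − lr · buf. The snippet below reproduces that rule by hand on a made-up quadratic loss and compares it with the library over ten steps; the loss and hyperparameters are illustrative assumptions.

```python
import torch
import torch.optim as optim

torch.manual_seed(0)
target = torch.randn(5)

def loss_fn(p):
    return ((p - target) ** 2).sum()       # illustrative quadratic loss

lr, momentum = 0.1, 0.9

# Library version
p_lib = torch.zeros(5, requires_grad=True)
optimizer = optim.SGD([p_lib], lr=lr, momentum=momentum)

# Hand-written version of the same rule
p_manual = torch.zeros(5, requires_grad=True)
buf = torch.zeros(5)                        # momentum buffer starts at zero

for _ in range(10):
    optimizer.zero_grad()
    loss_fn(p_lib).backward()
    optimizer.step()

    p_manual.grad = None
    loss_fn(p_manual).backward()
    with torch.no_grad():
        buf = momentum * buf + p_manual.grad   # buf <- momentum * buf + grad
        p_manual -= lr * buf                   # p <- p - lr * buf

print(torch.allclose(p_lib, p_manual))
```

Because the buffer keeps 90% of its previous value at each step, gradients that keep pointing the same way compound; on a long, consistent slope the effective step approaches lr / (1 − momentum), roughly 10× lr here. That is one reason SGD with momentum needs its learning rate tuned together with the momentum value.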
Key Optimizer Parameters
Learning rate controls step size during optimization. Too high causes divergence; too low causes slow convergence. This is the most important hyperparameter to tune. Use 1e-3 for Adam/AdamW, 1e-1 for SGD as starting points.
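Momentum (SGD) keeps an exponentially decaying running sum of past gradients and steps along it, which damps oscillations, speeds convergence along consistent directions, and can help carry the parameters through small local minima. Values between 0.9 and 0.99 work well; 0.9 is a common default.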
Betas (Adam/AdamW) control the exponential decay rates of the first- and second-moment estimates. The defaults (0.9, 0.999) work for most cases; some transformer training recipes use (0.9, 0.98) instead.
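Weight decay penalizes large weights to reduce overfitting. Typical values range from 1e-5 to 1e-2. It behaves differently in Adam and AdamW: in Adam (as in SGD), weight_decay adds weight_decay × θ to the gradient, which is equivalent to an L2 penalty, and that term is then rescaled by Adam's adaptive denominator; AdamW instead shrinks the weights directly by lr × weight_decay × θ at each step, outside the adaptive scaling.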
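The Adam-versus-AdamW difference can be made concrete with a hand-written update. In the sketch below, adam_step is a hypothetical helper written for illustration: with decoupled=False it folds weight decay into the gradient (what torch.optim.Adam does with weight_decay), and with decoupled=True it shrinks the weights first and then applies the plain Adam step (what torch.optim.AdamW does). The quadratic loss and all hyperparameter values are arbitrary choices.

```python
import torch
import torch.optim as optim

torch.manual_seed(0)
target = torch.randn(5)
lr, betas, eps, wd = 1e-2, (0.9, 0.999), 1e-8, 0.1

def grad_of(p):
    return 2 * (p - target)            # gradient of ((p - target) ** 2).sum()

def adam_step(p, m, v, t, decoupled):
    """One Adam (decoupled=False) or AdamW (decoupled=True) update; hypothetical helper."""
    g = grad_of(p)                     # gradient at the current weights
    if decoupled:
        p = p * (1 - lr * wd)          # AdamW: decay the weights directly
    else:
        g = g + wd * p                 # Adam: weight decay as an L2 term in the gradient
    m = betas[0] * m + (1 - betas[0]) * g          # first-moment estimate
    v = betas[1] * v + (1 - betas[1]) * g * g      # second-moment estimate
    m_hat = m / (1 - betas[0] ** t)                # bias correction
    v_hat = v / (1 - betas[1] ** t)
    return p - lr * m_hat / (v_hat.sqrt() + eps), m, v

def run_manual(decoupled, steps=20):
    p, m, v = torch.ones(5), torch.zeros(5), torch.zeros(5)
    for t in range(1, steps + 1):
        p, m, v = adam_step(p, m, v, t, decoupled)
    return p

def run_library(opt_cls, steps=20):
    p = torch.ones(5, requires_grad=True)
    opt = opt_cls([p], lr=lr, betas=betas, eps=eps, weight_decay=wd)
    for _ in range(steps):
        opt.zero_grad()
        ((p - target) ** 2).sum().backward()
        opt.step()
    return p.detach()

adam_lib, adamw_lib = run_library(optim.Adam), run_library(optim.AdamW)
print(torch.allclose(adam_lib, run_manual(decoupled=False), atol=1e-6))
print(torch.allclose(adamw_lib, run_manual(decoupled=True), atol=1e-6))
print((adam_lib - adamw_lib).abs().max())   # same weight_decay value, different results
```

The same weight_decay number therefore means different things in the two optimizers, so a value tuned for Adam does not transfer directly to AdamW.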
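```python
import matplotlib.pyplot as plt
import torch
import torch.optim as optim

# Assumes SimpleNet, criterion, and train_loader from the setup example above
# Demonstrate learning rate impact
def train_with_lr(model, optimizer, train_loader, epochs=10):
    model.train()
    losses = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        losses.append(epoch_loss / len(train_loader))  # mean loss per batch this epoch
    return losses

# Compare different learning rates
learning_rates = [1e-4, 1e-3, 1e-2, 1e-1]
results = {}
for lr in learning_rates:
    torch.manual_seed(0)  # same initial weights for every run, for a fair comparison
    model = SimpleNet()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    losses = train_with_lr(model, optimizer, train_loader)
    results[lr] = losses
    print(f"LR {lr}: Final loss = {losses[-1]:.4f}")

# Plot the loss curves
for lr, losses in results.items():
    plt.plot(losses, label=f"lr={lr}")
plt.xlabel("Epoch")
plt.ylabel("Mean training loss")
plt.legend()
plt.show()
```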
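The dataset-based comparison above takes a while to run. A quicker way to see the "too high diverges, too low crawls" behavior is plain SGD on the one-dimensional loss f(x) = x², whose gradient is 2x, so each step multiplies x by (1 − 2 · lr). This toy problem is an illustrative assumption, not part of the original experiment.

```python
import torch
import torch.optim as optim

def minimize_quadratic(lr, steps=50):
    """Run plain SGD on f(x) = x**2 starting from x = 1 and return the final x."""
    x = torch.tensor([1.0], requires_grad=True)
    optimizer = optim.SGD([x], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        (x ** 2).sum().backward()
        optimizer.step()
    return x.item()

for lr in [0.001, 0.1, 0.9, 1.1]:
    print(f"lr={lr}: x after 50 steps = {minimize_quadratic(lr):.4g}")
```

With lr=0.001 x barely moves (factor 0.998 per step), lr=0.1 converges smoothly, lr=0.9 still converges but overshoots the minimum on every step (factor −0.8), and lr=1.1 diverges because |1 − 2 · lr| > 1.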
Learning Rate Scheduling
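A fixed learning rate is usually a compromise: a value large enough for fast early progress is often too large to settle into a good minimum late in training. Learning rate schedulers adjust the learning rate during training to improve convergence and final performance.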
StepLR reduces learning rate by a factor every N epochs. Simple and effective for many tasks.
ReduceLROnPlateau reduces learning rate when a metric plateaus. Useful when you don’t know the optimal schedule in advance.
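CosineAnnealingLR decays the learning rate from its initial value to eta_min (default 0) along a half cosine curve over T_max scheduler steps. Cosine schedules are popular in modern training recipes, including for ResNets and transformers.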
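To see what each schedule actually does, you can step a scheduler without real training and record the learning rate. This sketch uses a single dummy parameter with an initial lr of 1.0 and made-up validation losses for ReduceLROnPlateau; all values are chosen only to make the schedules easy to read. optimizer.step() is called before scheduler.step() so PyTorch does not warn about the call order; because the dummy parameter has no gradient, that step leaves it unchanged.

```python
import torch
import torch.optim as optim

def make_optimizer():
    dummy = torch.zeros(1, requires_grad=True)   # placeholder parameter
    return optim.SGD([dummy], lr=1.0)

def trace(optimizer, scheduler, epochs, metrics=None):
    """Record the learning rate in effect during each epoch."""
    lrs = []
    for epoch in range(epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()                          # no grad, so a no-op update
        if metrics is None:
            scheduler.step()
        else:
            scheduler.step(metrics[epoch])        # ReduceLROnPlateau needs the metric
    return lrs

opt = make_optimizer()
step_lrs = trace(opt, optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1), epochs=6)

opt = make_optimizer()
cos_lrs = trace(opt, optim.lr_scheduler.CosineAnnealingLR(opt, T_max=4), epochs=5)

opt = make_optimizer()
fake_val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]   # improvement stalls after epoch 1
plateau_lrs = trace(
    opt,
    optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.1, patience=2),
    epochs=7,
    metrics=fake_val_losses,
)

print("StepLR:           ", [round(lr, 4) for lr in step_lrs])
print("CosineAnnealingLR:", [round(lr, 4) for lr in cos_lrs])
print("ReduceLROnPlateau:", [round(lr, 4) for lr in plateau_lrs])
```

StepLR drops by 10× every two epochs, the cosine schedule glides from 1.0 to 0 over T_max=4 epochs, and ReduceLROnPlateau cuts the rate only after the loss has failed to improve for more than patience=2 consecutive epochs.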
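```python
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Assumes SimpleNet, criterion, and train_loader from the setup example;
# val_loader below is placeholder random data in the same format
val_loader = DataLoader(
    TensorDataset(torch.randn(128, 784), torch.randint(0, 10, (128,))),
    batch_size=64,
)

# Complete training loop with scheduler
model = SimpleNet()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# Reduce LR by 0.1 every 30 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# Alternative: reduce on plateau
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(
#     optimizer, mode='min', factor=0.1, patience=10
# )

num_epochs = 100
for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0.0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)  # mean over batches

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            val_loss += criterion(output, target).item()
    val_loss /= len(val_loader)

    # Step scheduler once per epoch, after all optimizer steps
    scheduler.step()
    # For ReduceLROnPlateau: scheduler.step(val_loss)

    # Log the learning rate (after scheduler.step(), this is the value for the next epoch)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, "
          f"Val Loss = {val_loss:.4f}, LR = {current_lr:.6f}")
```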
Always use scheduling for production models. It’s one of the easiest ways to improve performance.
Advanced Techniques
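Gradient clipping limits the damage from exploding gradients by rescaling them before the update, which is especially important for RNNs and transformers. Call it between loss.backward() and optimizer.step():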
```python
loss.backward()
# Clip gradients by norm: rescale so their combined L2 norm is at most 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```
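To check what clip_grad_norm_ does, the sketch below sets gradients by hand (the values are arbitrary) so that their combined L2 norm is 5.0, then clips to max_norm=1.0. The function returns the total norm measured before clipping and rescales every gradient in place by the same factor, so the direction of the update is preserved.

```python
import torch

w1 = torch.zeros(2, requires_grad=True)
w2 = torch.zeros(1, requires_grad=True)
w1.grad = torch.tensor([3.0, 0.0])
w2.grad = torch.tensor([4.0])      # total norm = sqrt(3**2 + 4**2) = 5

total_norm = torch.nn.utils.clip_grad_norm_([w1, w2], max_norm=1.0)
clipped_norm = torch.cat([w1.grad, w2.grad]).norm()

print(total_norm.item())     # norm before clipping
print(clipped_norm.item())   # about 1.0 after clipping
print(w1.grad, w2.grad)      # both scaled by the same factor (about 1/5)
```

Logging the returned total norm each step is a cheap way to spot the gradient spikes that clipping is meant to tame.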
Multiple optimizers are the usual approach when separate networks are trained against different objectives in alternation, as in GANs:
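```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

latent_dim, data_dim = 16, 32  # illustrative sizes

# Minimal stand-in architectures; replace with your own Generator/Discriminator
generator = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, data_dim))
discriminator = nn.Sequential(nn.Linear(data_dim, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))

# Separate optimizers for each network
optimizer_G = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()  # discriminator outputs raw logits

# Placeholder "real" samples with dummy labels
dataloader = DataLoader(TensorDataset(torch.randn(256, data_dim), torch.zeros(256)), batch_size=32)

# Training loop
for real_data, _ in dataloader:
    batch_size = real_data.size(0)
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)
    fake_data = generator(torch.randn(batch_size, latent_dim))

    # Train discriminator; detach() keeps this backward pass out of the generator
    optimizer_D.zero_grad()
    loss_D = (criterion(discriminator(real_data), real_labels)
              + criterion(discriminator(fake_data.detach()), fake_labels))
    loss_D.backward()
    optimizer_D.step()

    # Train generator: push the discriminator to label fakes as real.
    # This backward also fills discriminator grads; optimizer_D.zero_grad() clears them next iteration.
    optimizer_G.zero_grad()
    loss_G = criterion(discriminator(fake_data), real_labels)
    loss_G.backward()
    optimizer_G.step()
```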
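Saving and loading optimizer state is crucial for resuming training: optimizers such as SGD with momentum and Adam keep per-parameter buffers (momentum, moment estimates) that are lost if you restore only the model weights. Save the scheduler state too, so the learning rate schedule picks up where it left off: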
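```python
# Save checkpoint (e.g. at the end of an epoch inside the training loop)
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': loss.item(),  # store a plain Python float rather than the loss tensor
}
torch.save(checkpoint, 'checkpoint.pth')

# Load checkpoint: first recreate model, optimizer, and scheduler with the same settings
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # resume with the next epoch
loss = checkpoint['loss']
model.train()  # restore training mode before continuing
```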
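Why the optimizer state matters can be checked directly. This sketch (a toy linear model, random data, and an in-memory io.BytesIO buffer in place of a file, all illustrative assumptions) trains Adam for 10 steps without interruption, then repeats the run with a save and reload after step 5. Restoring the full checkpoint reproduces the uninterrupted run; restoring only the model weights restarts Adam's moment estimates and bias correction, so the trajectory drifts.

```python
import copy
import io
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
X, y = torch.randn(64, 8), torch.randn(64, 1)    # toy regression data

def train_steps(model, optimizer, n):
    for _ in range(n):
        optimizer.zero_grad()
        nn.functional.mse_loss(model(X), y).backward()
        optimizer.step()

initial = nn.Linear(8, 1)

# Reference: 10 uninterrupted Adam steps
ref_model = copy.deepcopy(initial)
train_steps(ref_model, optim.Adam(ref_model.parameters(), lr=0.1), 10)

# Interrupted run: 5 steps, then checkpoint model and optimizer state
model = copy.deepcopy(initial)
optimizer = optim.Adam(model.parameters(), lr=0.1)
train_steps(model, optimizer, 5)
buffer = io.BytesIO()
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# Resume A: restore model and optimizer state, then 5 more steps
resumed = nn.Linear(8, 1)
resumed_opt = optim.Adam(resumed.parameters(), lr=0.1)
resumed.load_state_dict(checkpoint['model_state_dict'])
resumed_opt.load_state_dict(checkpoint['optimizer_state_dict'])
train_steps(resumed, resumed_opt, 5)

# Resume B: restore only the model weights (fresh Adam state)
weights_only = nn.Linear(8, 1)
weights_only.load_state_dict(checkpoint['model_state_dict'])
train_steps(weights_only, optim.Adam(weights_only.parameters(), lr=0.1), 5)

for name, m in [("full checkpoint", resumed), ("weights only", weights_only)]:
    gap = (m.weight - ref_model.weight).abs().max().item()
    print(f"{name}: max |weight difference| vs reference = {gap:.2e}")
```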
Best Practices and Common Pitfalls
Always call zero_grad() before backpropagation. Forgetting this accumulates gradients across batches, which is rarely what you want:
```python
# WRONG - gradients accumulate
for data, target in train_loader:
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()  # Missing zero_grad()!

# CORRECT
for data, target in train_loader:
    optimizer.zero_grad()  # Clear gradients
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
```
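The accumulation is easy to observe on a single scalar. In this minimal sketch the parameter and loss are illustrative. Note that since PyTorch 2.0, zero_grad() defaults to set_to_none=True, so it resets .grad to None rather than filling it with zeros.

```python
import torch
import torch.optim as optim

w = torch.tensor(3.0, requires_grad=True)
optimizer = optim.SGD([w], lr=0.1)

(w ** 2).backward()
first = w.grad.clone()    # d(w^2)/dw = 2w = 6

(w ** 2).backward()       # second backward without zero_grad()
second = w.grad.clone()   # 6 + 6 = 12: the new gradient was added to the old one

optimizer.zero_grad()
print(first.item(), second.item(), w.grad)   # w.grad is None after zero_grad()
```

This additive behavior is also what makes deliberate gradient accumulation possible; outside that case, forgetting zero_grad() silently makes every update larger than intended.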
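Step schedulers at the right time. Schedulers whose period is measured in epochs, such as StepLR or ReduceLROnPlateau, should be called once per epoch, after that epoch's optimizer steps; per-iteration schedulers such as OneCycleLR are stepped after every optimizer.step() instead. Since PyTorch 1.1.0, calling scheduler.step() before optimizer.step() skips the first value of the schedule: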
```python
# WRONG - scheduler before optimizer
for epoch in range(num_epochs):
    scheduler.step()  # Too early!
    for data, target in train_loader:
        # ... training ...
        optimizer.step()

# CORRECT
for epoch in range(num_epochs):
    for data, target in train_loader:
        # ... training ...
        optimizer.step()
    scheduler.step()  # After all batches
```
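The skipped first value is easy to reproduce. In this sketch (a single dummy parameter and a StepLR that halves an initial lr of 1.0 every epoch, both illustrative), we record the learning rate in effect at each epoch's optimizer.step() under both orderings. PyTorch also emits a UserWarning for the wrong order, which is silenced here only to keep the output short.

```python
import warnings
import torch
import torch.optim as optim

def lrs_seen(scheduler_first, epochs=3):
    """Return the learning rate used by optimizer.step() in each epoch."""
    dummy = torch.zeros(1, requires_grad=True)    # placeholder parameter, no gradient
    optimizer = optim.SGD([dummy], lr=1.0)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    seen = []
    for _ in range(epochs):
        if scheduler_first:
            scheduler.step()                      # WRONG: too early
        seen.append(optimizer.param_groups[0]["lr"])
        optimizer.step()                          # stands in for a whole epoch of batches
        if not scheduler_first:
            scheduler.step()                      # CORRECT: after the epoch's updates
    return seen

with warnings.catch_warnings():
    warnings.simplefilter("ignore")               # hide the expected order warning
    wrong = lrs_seen(scheduler_first=True)
correct = lrs_seen(scheduler_first=False)

print("scheduler first:", wrong)    # [0.5, 0.25, 0.125]: the initial 1.0 is never used
print("optimizer first:", correct)  # [1.0, 0.5, 0.25]
```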
Start with proven defaults: AdamW with lr=1e-3, weight_decay=0.01, and a cosine annealing schedule. Only deviate when you have a specific reason.
Monitor learning rates during training. Log them to ensure schedulers are working correctly.
Use gradient clipping if you see exploding gradients or NaN losses; if the loss still turns into NaN, try a lower learning rate.
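Putting these tips together, here is a hedged sketch of that default setup: AdamW (lr=1e-3, weight_decay=0.01) with CosineAnnealingLR over the full run, per-epoch learning-rate logging, gradient clipping, and a guard that stops on a non-finite loss. The tiny regression model, random data, epoch count, and max_norm value are placeholder assumptions.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Placeholder data and model; swap in your own
loader = DataLoader(TensorDataset(torch.randn(256, 10), torch.randn(256, 1)), batch_size=32)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
criterion = nn.MSELoss()

num_epochs = 10
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

lr_history = []
for epoch in range(num_epochs):
    lr_history.append(optimizer.param_groups[0]["lr"])   # LR used during this epoch
    total = 0.0
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        if not torch.isfinite(loss):                      # stop early instead of training on NaN/inf
            raise RuntimeError(f"Non-finite loss at epoch {epoch}; try a lower LR")
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total += loss.item()
    scheduler.step()                                      # once per epoch, after the optimizer steps
    print(f"epoch {epoch}: lr={lr_history[-1]:.2e}, loss={total / len(loader):.4f}")
```

The logged learning rates should fall smoothly from 1e-3 toward zero; if they stay flat or jump, the scheduler is probably being stepped in the wrong place.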
The optimizer is the most critical component of your training pipeline after the model architecture itself. Master these patterns, and you’ll train models more effectively and debug issues faster.