Deep Learning: Optimizers Explained
Key Insights
- Adam remains the default choice for most deep learning tasks, but SGD with momentum often achieves better generalization on computer vision problems when properly tuned
- Adaptive optimizers like Adam work well out-of-the-box but can get stuck in sharp minima; combining them with learning rate schedules and weight decay (AdamW) significantly improves performance
- The choice of optimizer often matters less than proper learning rate tuning: a well-tuned SGD will typically outperform a poorly configured Adam
Introduction to Optimization in Deep Learning
Training a neural network boils down to solving an optimization problem: finding the weights that minimize your loss function. This is harder than it sounds. Neural network loss landscapes are non-convex, riddled with saddle points, local minima, and vast flat regions where gradients vanish. Your optimizer is the algorithm that navigates this treacherous terrain.
The choice of optimizer directly impacts training speed, final model accuracy, and generalization performance. A poor optimizer might take days to converge or get stuck in suboptimal solutions. A good one finds better minima faster and produces models that perform well on unseen data.
Let’s visualize what optimizers are up against:
import numpy as np
import matplotlib.pyplot as plt
# Create a complex loss landscape
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
# Non-convex function with multiple minima and saddle points
Z = (X**2 + Y - 11)**2 + (X + Y**2 - 7)**2 # Himmelblau's function
plt.figure(figsize=(10, 8))
plt.contour(X, Y, Z, levels=50, cmap='viridis')
plt.colorbar(label='Loss')
plt.title('Complex Loss Landscape with Multiple Minima')
plt.xlabel('Weight 1')
plt.ylabel('Weight 2')
plt.show()
This landscape shows why optimization is challenging. Multiple local minima exist, and the path from initialization to a good solution isn’t straightforward.
Gradient Descent Fundamentals
Gradient descent updates weights by moving in the direction opposite to the gradient: w = w - η∇L(w), where η is the learning rate. Three variants exist:
- Batch Gradient Descent: Computes gradients using the entire dataset. Stable but slow and memory-intensive.
- Stochastic Gradient Descent (SGD): Uses one sample at a time. Fast but noisy updates.
- Mini-batch SGD: The practical middle ground, using batches of 32-256 samples.
The learning rate is critical. Too high, and you overshoot minima. Too low, and training crawls.
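To make this concrete, here is a small sketch of gradient descent on a one-dimensional quadratic (the gd_quadratic helper and its parameters are illustrative, not from any library). The product of learning rate and curvature decides between divergence, convergence, and crawling:

```python
import numpy as np

def gd_quadratic(lr, steps=50, w0=1.0, curvature=10.0):
    """Gradient descent on f(w) = 0.5 * curvature * w**2.
    The gradient is curvature * w, so each step multiplies w by (1 - lr * curvature)."""
    w = w0
    for _ in range(steps):
        w = w - lr * curvature * w
    return w

# lr * curvature > 2 diverges; moderate values converge; tiny values crawl
for lr in [0.25, 0.05, 0.001]:
    print(f"lr={lr}: w after 50 steps = {gd_quadratic(lr):.4g}")
```

With curvature 10, lr=0.25 gives a per-step factor of -1.5 (divergence), lr=0.05 gives 0.5 (fast convergence), and lr=0.001 gives 0.99 (barely moves in 50 steps).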
import numpy as np

class SGD:
    def __init__(self, learning_rate=0.01):
        self.lr = learning_rate

    def update(self, params, grads):
        return params - self.lr * grads

# Simple regression example
np.random.seed(42)
X = np.random.randn(1000, 10)
y = X @ np.random.randn(10) + np.random.randn(1000) * 0.1

# Initialize weights
w = np.random.randn(10) * 0.01
batch_size = 32
optimizer = SGD(learning_rate=0.01)
losses = []

for epoch in range(100):
    indices = np.random.permutation(len(X))
    for i in range(0, len(X), batch_size):
        batch_idx = indices[i:i+batch_size]
        X_batch, y_batch = X[batch_idx], y[batch_idx]
        # Forward pass
        pred = X_batch @ w
        loss = np.mean((pred - y_batch)**2)
        # Backward pass
        grad = 2 * X_batch.T @ (pred - y_batch) / len(X_batch)
        # Update
        w = optimizer.update(w, grad)
    losses.append(loss)
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")
Plain SGD works but has limitations. It treats all parameters equally and can oscillate in narrow valleys.
Momentum-Based Optimizers
Momentum accumulates past gradients to smooth updates and accelerate convergence. Think of it as a ball rolling downhill, building velocity.
SGD with Momentum: v = βv + ∇L(w) and w = w - ηv, where β (typically 0.9) controls momentum strength.
Nesterov Accelerated Gradient (NAG): Looks ahead by computing gradients at the anticipated position, providing better correction.
class MomentumSGD:
    def __init__(self, learning_rate=0.01, momentum=0.9):
        self.lr = learning_rate
        self.momentum = momentum
        self.velocity = None

    def update(self, params, grads):
        if self.velocity is None:
            self.velocity = np.zeros_like(params)
        self.velocity = self.momentum * self.velocity + grads
        return params - self.lr * self.velocity

# Compare SGD vs Momentum on Rosenbrock function
def rosenbrock(x, y):
    return (1 - x)**2 + 100 * (y - x**2)**2

def rosenbrock_grad(x, y):
    dx = -2 * (1 - x) - 400 * x * (y - x**2)
    dy = 200 * (y - x**2)
    return np.array([dx, dy])

# Run both optimizers
start = np.array([-1.0, 1.0])
sgd_path = [start.copy()]
mom_path = [start.copy()]
sgd = SGD(learning_rate=0.001)
momentum = MomentumSGD(learning_rate=0.001, momentum=0.9)
pos_sgd, pos_mom = start.copy(), start.copy()

for _ in range(1000):
    grad_sgd = rosenbrock_grad(*pos_sgd)
    grad_mom = rosenbrock_grad(*pos_mom)
    pos_sgd = sgd.update(pos_sgd, grad_sgd)
    pos_mom = momentum.update(pos_mom, grad_mom)
    sgd_path.append(pos_sgd.copy())
    mom_path.append(pos_mom.copy())

print(f"SGD final position: {pos_sgd}, Loss: {rosenbrock(*pos_sgd):.4f}")
print(f"Momentum final position: {pos_mom}, Loss: {rosenbrock(*pos_mom):.4f}")
Momentum typically converges faster and handles noisy gradients better, especially in regions with different curvatures.
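NAG was described above but not implemented. Here is a minimal sketch in the same style as MomentumSGD; the lookahead helper is our own naming for the point at which the caller evaluates the gradient. The demo minimizes a simple quadratic bowl rather than the Rosenbrock function:

```python
import numpy as np

class NesterovSGD:
    """SGD with Nesterov momentum: the gradient is evaluated at the
    look-ahead point w - lr * momentum * v instead of at w itself."""
    def __init__(self, learning_rate=0.01, momentum=0.9):
        self.lr = learning_rate
        self.momentum = momentum
        self.velocity = None

    def lookahead(self, params):
        # The position at which the caller should evaluate the gradient
        if self.velocity is None:
            self.velocity = np.zeros_like(params, dtype=float)
        return params - self.lr * self.momentum * self.velocity

    def update(self, params, grads_at_lookahead):
        if self.velocity is None:
            self.velocity = np.zeros_like(params, dtype=float)
        self.velocity = self.momentum * self.velocity + grads_at_lookahead
        return params - self.lr * self.velocity

# Demo on a quadratic bowl f(w) = 0.5 * ||w||^2, whose gradient is just w
nag = NesterovSGD(learning_rate=0.1, momentum=0.9)
w = np.array([3.0, -2.0])
for _ in range(300):
    grad_at_la = nag.lookahead(w)  # gradient of the bowl at the look-ahead point
    w = nag.update(w, grad_at_la)
```

The look-ahead evaluation gives the velocity a chance to correct itself before overshooting, which is why NAG tends to oscillate less than plain momentum near a minimum.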
Adaptive Learning Rate Optimizers
Adaptive optimizers adjust learning rates per parameter based on gradient history. This is powerful for models with sparse gradients or varying parameter scales.
AdaGrad: Accumulates squared gradients, reducing learning rates for frequently updated parameters. Works well for sparse data but can be too aggressive, causing premature convergence.
RMSprop: Uses exponential moving average of squared gradients instead of cumulative sum, preventing learning rate decay from becoming too aggressive.
Adam (Adaptive Moment Estimation): Combines momentum with RMSprop. Maintains both first moment (mean) and second moment (variance) of gradients. It’s the most popular optimizer for good reason—it works well with minimal tuning.
AdamW: Adam with decoupled weight decay. Separates L2 regularization from the adaptive learning rate mechanism, improving generalization.
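To see what Adam actually computes, here is a minimal NumPy sketch of the update rule (the AdamNumpy class name is ours; the hyperparameter defaults match the common convention). AdamW's only difference is noted in a comment:

```python
import numpy as np

class AdamNumpy:
    """Minimal sketch of the Adam update: first moment (mean) m,
    second moment (uncentered variance) v, with bias correction."""
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = None
        self.v = None
        self.t = 0

    def update(self, params, grads):
        if self.m is None:
            self.m = np.zeros_like(params)
            self.v = np.zeros_like(params)
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grads
        self.v = self.beta2 * self.v + (1 - self.beta2) * grads**2
        # Bias correction compensates for the zero initialization of m and v
        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)
        return params - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

# AdamW differs only in applying weight decay directly to the parameters
# (params -= lr * weight_decay * params) instead of folding it into grads.
adam = AdamNumpy(lr=0.1)
w = np.array([3.0])
for _ in range(200):
    w = adam.update(w, w)  # gradient of 0.5 * w**2 is w
```

The per-parameter division by sqrt(v_hat) is what makes Adam adaptive: parameters with consistently large gradients take smaller effective steps.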
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Simple CNN for MNIST
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Training function
def train_with_optimizer(optimizer_name, lr=0.001, epochs=5):
    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    if optimizer_name == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    elif optimizer_name == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif optimizer_name == 'RMSprop':
        optimizer = optim.RMSprop(model.parameters(), lr=lr)
    elif optimizer_name == 'AdamW':
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    # Load MNIST
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)

    losses = []
    for epoch in range(epochs):
        epoch_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f"{optimizer_name} - Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        losses.append(epoch_loss / len(train_loader))
    return losses
# Compare optimizers
optimizers = ['SGD', 'Adam', 'RMSprop', 'AdamW']
results = {}
for opt in optimizers:
    print(f"\nTraining with {opt}")
    results[opt] = train_with_optimizer(opt)

# Plot results
plt.figure(figsize=(10, 6))
for opt, losses in results.items():
    plt.plot(losses, label=opt, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.title('Optimizer Comparison on MNIST')
plt.legend()
plt.grid(True)
plt.show()
In practice, Adam converges fastest initially, but SGD with momentum often achieves better final test accuracy on vision tasks.
Advanced Optimizers and Recent Developments
Recent research has produced specialized optimizers for specific scenarios:
RAdam (Rectified Adam): Addresses Adam’s poor early-stage convergence by applying a rectification term to the adaptive learning rate. Particularly useful when you can’t afford extensive warmup.
AdaBound: Transitions from Adam-like behavior to SGD-like behavior as training progresses, combining fast initial convergence with better final generalization.
Lookahead: A meta-optimizer that wraps other optimizers, maintaining slow and fast weights to improve stability.
LAMB (Layer-wise Adaptive Moments optimizer for Batch training): Designed for large batch training, enabling batch sizes of 32K+ without loss of accuracy. Critical for distributed training.
# RAdam vs Adam comparison (requires the torch-optimizer package)
# pip install torch-optimizer
import torch_optimizer as custom_optim

def compare_adam_variants():
    model_adam = SimpleCNN()
    model_radam = SimpleCNN()
    # Copy initial weights so both models start from the same point
    model_radam.load_state_dict(model_adam.state_dict())
    optimizer_adam = optim.Adam(model_adam.parameters(), lr=0.001)
    optimizer_radam = custom_optim.RAdam(model_radam.parameters(), lr=0.001)
    # Training loop similar to above
    # RAdam typically shows more stable early training
    return model_adam, model_radam
# For transformers, RAdam often outperforms Adam
# especially without learning rate warmup
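Lookahead's slow/fast-weight scheme from the list above fits in a few lines. This sketch uses plain SGD as the inner optimizer; train_lookahead is a hypothetical helper, not a library function, and the demo target is a simple quadratic bowl:

```python
import numpy as np

def train_lookahead(grad_fn, w0, inner_lr=0.1, k=5, alpha=0.5, outer_steps=50):
    """Lookahead sketch: run k fast SGD updates, then pull the slow
    weights toward the fast ones and restart the fast weights there."""
    slow = w0.copy()
    fast = w0.copy()
    for _ in range(outer_steps):
        for _ in range(k):                   # k fast inner-optimizer steps
            fast = fast - inner_lr * grad_fn(fast)
        slow = slow + alpha * (fast - slow)  # slow-weight interpolation
        fast = slow.copy()                   # reset fast weights
    return slow

# Demo on a quadratic bowl: the gradient of 0.5 * ||w||^2 is w
w_final = train_lookahead(lambda w: w, np.array([2.0, -1.0]))
```

The interpolation step acts as a stabilizer: even if the fast weights oscillate, the slow weights only move a fraction alpha of the way, which is where the claimed stability improvement comes from.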
Choosing the Right Optimizer
Here’s practical guidance:
Use Adam/AdamW when:
- Starting a new project (best default choice)
- Working with RNNs, transformers, or NLP models
- You need fast prototyping
- Dealing with sparse gradients
Use SGD with momentum when:
- Training CNNs for production (ResNets, EfficientNets)
- You have time for hyperparameter tuning
- Final test accuracy matters more than training speed
- Working with small batch sizes
Use RMSprop when:
- Training RNNs (though Adam has largely replaced it)
- Dealing with non-stationary objectives
Use specialized optimizers when:
- LAMB: Training with very large batches (>8K)
- RAdam: You can’t use learning rate warmup
- AdaBound: You want Adam’s speed with SGD’s generalization
Hyperparameter tuning example:
import optuna
def objective(trial):
# Suggest optimizer
optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'AdamW'])
lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
if optimizer_name == 'SGD':
momentum = trial.suggest_uniform('momentum', 0.8, 0.99)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
elif optimizer_name == 'Adam':
optimizer = optim.Adam(model.parameters(), lr=lr)
else:
weight_decay = trial.suggest_loguniform('weight_decay', 1e-5, 1e-2)
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# Train and return validation accuracy
val_acc = train_and_evaluate(model, optimizer)
return val_acc
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)
print(f"Best optimizer config: {study.best_params}")
Conclusion and Best Practices
The optimizer landscape is mature. Adam/AdamW handles 80% of use cases effectively. SGD with momentum remains king for computer vision when properly tuned. Don’t chase exotic optimizers unless you have specific needs.
Critical best practices:
- Always use learning rate scheduling: Cosine annealing, step decay, or ReduceLROnPlateau dramatically improve results
- Apply gradient clipping: Prevents exploding gradients, especially in RNNs. Use torch.nn.utils.clip_grad_norm_
- Use weight decay properly: With Adam, use AdamW or decouple weight decay from the optimizer
- Monitor gradient norms: If they explode or vanish, adjust learning rate or architecture
- Warm up learning rates: Start with low learning rate for first few epochs, especially with large batch sizes
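The scheduling, warmup, and clipping advice above can be combined in one PyTorch sketch. The tiny linear model and synthetic batches are placeholders standing in for a real training setup:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

# Linear warmup for the first 5 epochs, then cosine annealing for the rest
warmup_epochs, total_epochs = 5, 50
warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)
scheduler = optim.lr_scheduler.SequentialLR(optimizer, [warmup, cosine],
                                            milestones=[warmup_epochs])

for epoch in range(total_epochs):
    x, y = torch.randn(32, 10), torch.randn(32, 1)  # placeholder batch
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm before stepping to prevent exploding gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
```

The learning rate ramps from 1e-4 to 1e-3 during warmup, then decays smoothly toward zero; clipping caps each step's gradient norm at 1.0 regardless of the batch.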
The difference between a mediocre model and a great one often comes down to optimization details. Master these fundamentals, and you’ll train better models faster.