How to Use Mixed Precision Training in PyTorch

Key Insights

  • Mixed precision training combines FP16 and FP32 computations and, on GPUs with Tensor Cores, typically delivers 1.5-3x speedups and 30-50% memory savings with minimal code changes: wrap your forward pass in autocast and use GradScaler for backpropagation.
  • Loss scaling is critical to prevent gradient underflow when using FP16, and PyTorch’s GradScaler handles this automatically by dynamically adjusting the scale factor throughout training.
  • Not all operations benefit from FP16; PyTorch’s AMP keeps precision-sensitive operations such as softmax, layer normalization, and most loss functions in FP32 while running matrix multiplications and convolutions in FP16.

Introduction to Mixed Precision Training

Mixed precision training is one of the most effective optimizations you can apply to deep learning workloads. By combining 16-bit floating-point (FP16) and 32-bit floating-point (FP32) computations, you can significantly reduce training time and memory consumption without sacrificing model accuracy.

The performance gains can be substantial. On NVIDIA GPUs with Tensor Cores (V100, A100, RTX series), training speedups of 1.5-3x are typical. Memory savings of 30-50% let you use larger batch sizes or train bigger models on the same hardware. For ResNet-50 trained on ImageNet, a speedup at the upper end of that range could cut a 12-hour training run to roughly 4-5 hours.

You should use mixed precision training whenever you’re working with GPUs that support it and your model is compute-bound rather than I/O-bound. It’s particularly effective for transformer models, CNNs, and other architectures with heavy matrix operations.

Understanding the Fundamentals

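FP32 uses 32 bits to represent each number, providing high precision but consuming more memory and compute resources. FP16 uses only 16 bits, which halves the memory footprint and enables much faster arithmetic on hardware with FP16 support such as Tensor Cores. However, FP16’s much narrower range creates challenges for deep learning: its largest finite value is 65,504, its smallest normal value is about 6.1×10⁻⁵, and the smallest positive value it can represent at all (a subnormal) is about 6×10⁻⁸. It also keeps only about three significant decimal digits.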

The primary issue is gradient underflow. During backpropagation, gradients can become smaller than the smallest value FP16 can represent and are rounded to zero. Parameters whose gradients vanish this way stop learning, which can stall or silently degrade training.
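The FP16 limits quoted above can be checked directly with torch.finfo and a few casts on the CPU. This is a minimal sketch; the test values (2⁻²⁴, 1e-8, and 70,000) are arbitrary picks that sit at, below, and above the representable range.

```python
import torch

fp16 = torch.finfo(torch.float16)
print(f"Largest finite FP16 value: {fp16.max}")    # 65504.0
print(f"Smallest normal FP16 value: {fp16.tiny}")  # ~6.1e-05

# Cast float32 values to FP16 to see what survives
smallest_subnormal = torch.tensor([2.0 ** -24]).half()  # smallest positive FP16 value
too_small = torch.tensor([1e-8]).half()                 # below the subnormal range: rounds to zero
too_large = torch.tensor([70000.0]).half()              # above 65,504: overflows to inf

print(smallest_subnormal.item())  # ~5.96e-08
print(too_small.item())           # 0.0
print(too_large.item())           # inf
```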

Loss scaling solves this problem by multiplying the loss by a large factor before backpropagation, scaling up the gradients so they remain representable in FP16. After computing gradients, they’re scaled back down before the optimizer step.
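The arithmetic of loss scaling can be illustrated with a single value. This sketch compresses what happens across a whole backward pass into one number: the gradient value 1e-8 is an arbitrary example, and the scale of 65,536 (2¹⁶) is chosen because it matches GradScaler’s default initial scale.

```python
import torch

true_grad = 1e-8   # a gradient value too small for FP16
scale = 2.0 ** 16  # 65,536

# Without scaling, the value underflows when stored in FP16
unscaled_fp16 = torch.tensor([true_grad]).half()

# With scaling, the value is multiplied up before it is stored in FP16...
scaled_fp16 = torch.tensor([true_grad * scale]).half()

# ...and divided back down in FP32 before the optimizer uses it
recovered = scaled_fp16.float() / scale

print(unscaled_fp16.item())  # 0.0
print(recovered.item())      # close to 1e-08 (FP16 keeps about 3 significant digits)
```

Dividing by a power of two is exact in floating point, so the only error left is the FP16 rounding of the scaled value.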

PyTorch’s Automatic Mixed Precision (AMP) provides two key components:

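  • torch.cuda.amp.autocast: A context manager that automatically selects the appropriate precision for each operation
  • torch.cuda.amp.GradScaler: Scales the loss, unscales gradients before the optimizer step, skips steps whose gradients contain inf or NaN values, and adjusts the scale factor dynamically

In PyTorch 2.4 and later, the torch.cuda.amp spellings used in this article still work but emit deprecation warnings; the recommended equivalents are torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda").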
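To see autocast’s per-operation choice in action, the sketch below runs a linear layer inside torch.autocast, the device-generic form of torch.cuda.amp.autocast (torch.cuda.amp.autocast(...) is equivalent to torch.autocast("cuda", ...)). As an assumption made so the sketch runs without a GPU, it falls back to CPU autocast with bfloat16, the CPU default; on a CUDA device it uses float16.

```python
import torch
import torch.nn as nn

# Assumption: use CPU autocast (bfloat16) when no GPU is present
device = "cuda" if torch.cuda.is_available() else "cpu"
low_dtype = torch.float16 if device == "cuda" else torch.bfloat16

layer = nn.Linear(64, 32).to(device)
x = torch.randn(8, 64, device=device)

with torch.autocast(device_type=device, dtype=low_dtype):
    y = layer(x)  # linear layers are on autocast's low-precision list

print(x.dtype)             # torch.float32: the input tensor itself is not modified
print(layer.weight.dtype)  # torch.float32: parameters stay in FP32
print(y.dtype)             # torch.float16 on CUDA, torch.bfloat16 on CPU
```

Because parameters stay in FP32, the optimizer always updates full-precision weights; only the operation’s computation and output use the lower precision.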

Here’s a simple comparison of memory usage:

import torch

# FP32 tensor
tensor_fp32 = torch.randn(1000, 1000, dtype=torch.float32, device='cuda')
print(f"FP32 memory: {tensor_fp32.element_size() * tensor_fp32.nelement() / 1024**2:.2f} MB")

# FP16 tensor
tensor_fp16 = torch.randn(1000, 1000, dtype=torch.float16, device='cuda')
print(f"FP16 memory: {tensor_fp16.element_size() * tensor_fp16.nelement() / 1024**2:.2f} MB")

# Output:
# FP32 memory: 3.81 MB
# FP16 memory: 1.91 MB

Basic Implementation with torch.cuda.amp

Converting a standard training loop to use AMP requires minimal changes. Here’s a standard loop followed by its AMP equivalent:

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

# Standard training loop (FP32)
def train_standard(model, dataloader, optimizer, criterion):
    model.train()
    for inputs, targets in dataloader:
        inputs, targets = inputs.cuda(), targets.cuda()
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

# AMP training loop (Mixed Precision)
def train_amp(model, dataloader, optimizer, criterion):
    model.train()
    scaler = GradScaler()
    
    for inputs, targets in dataloader:
        inputs, targets = inputs.cuda(), targets.cuda()
        
        optimizer.zero_grad()
        
        # Forward pass with autocast
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        
        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

The changes are straightforward:

  1. Create a GradScaler instance
  2. Wrap the forward pass in autocast() context
  3. Scale the loss before calling backward()
  4. Use scaler.step() instead of optimizer.step()
  5. Call scaler.update() to adjust the scale factor
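The scale-factor adjustment in step 5 can be sketched as a small state machine. The class below is a hypothetical, pure-Python model of the policy, not GradScaler itself; its constants mirror GradScaler’s documented defaults (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000), and the found_inf flag stands in for the inf/NaN check GradScaler runs on the gradients.

```python
class DynamicLossScale:
    """Hypothetical model of dynamic loss scaling (not PyTorch's implementation)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps without inf/NaN gradients

    def update(self, found_inf):
        """Adjust the scale; return True if the optimizer step should be applied."""
        if found_inf:
            # Overflow: skip this step and shrink the scale
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # A long run of clean steps: try a larger scale
            self.scale *= self.growth_factor
            self.good_steps = 0
        return True


# growth_interval=3 only makes growth visible within a few steps
policy = DynamicLossScale(growth_interval=3)
history = []
for found_inf in [False, False, False, True, False]:
    applied = policy.update(found_inf)
    history.append((applied, policy.scale))
print(history)
# [(True, 65536.0), (True, 65536.0), (True, 131072.0), (False, 65536.0), (True, 65536.0)]
```

The scale creeps upward while training is stable and halves immediately on overflow, so it settles near the largest value the gradients tolerate.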

Here’s a complete example with a simple CNN:

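import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        # 128 * 8 * 8 assumes 32x32 inputs: two 2x2 max-pools shrink 32 -> 16 -> 8
        self.fc = nn.Linear(128 * 8 * 8, 10)
    
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Data: CIFAR-10 (3x32x32 images, 10 classes) fits the layer sizes above
train_set = datasets.CIFAR10(root='./data', train=True, download=True,
                             transform=transforms.ToTensor())
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

# Setup
model = SimpleCNN().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

# Training loop
for epoch in range(10):
    for inputs, targets in train_loader:
        inputs, targets = inputs.cuda(), targets.cuda()
        
        optimizer.zero_grad()
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()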

Advanced Techniques and Best Practices

When working with more complex training scenarios, you’ll need additional techniques.

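Gradient Clipping with AMP: The gradients produced by scaler.scale(loss).backward() are multiplied by the current scale factor, so you must unscale them before clipping; otherwise max_norm is compared against the scaled values: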

from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_

scaler = GradScaler()
max_norm = 1.0

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    
    scaler.scale(loss).backward()
    
    # Unscale before clipping
    scaler.unscale_(optimizer)
    clip_grad_norm_(model.parameters(), max_norm)
    
    scaler.step(optimizer)
    scaler.update()
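To see why unscaling matters, the sketch below runs the backward pass of a tiny model twice on the CPU, once with the loss multiplied by a scale factor, and compares the total gradient norms that clip_grad_norm_ reports. The model, data, and scale value are arbitrary; dividing each .grad by the scale stands in for what scaler.unscale_(optimizer) does.

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(16, 4), torch.randn(16, 1)
loss_fn = nn.MSELoss()
scale = 1024.0  # arbitrary scale factor for the illustration

def total_grad_norm():
    # max_norm=inf measures the total gradient norm without clipping anything
    return clip_grad_norm_(model.parameters(), max_norm=float("inf")).item()

# Reference: gradients of the unscaled loss
model.zero_grad()
loss_fn(model(x), y).backward()
true_norm = total_grad_norm()

# Scaled backward pass, like scaler.scale(loss).backward()
model.zero_grad()
(loss_fn(model(x), y) * scale).backward()
scaled_norm = total_grad_norm()

# Divide the gradients back down, like scaler.unscale_(optimizer)
for p in model.parameters():
    p.grad.div_(scale)
unscaled_norm = total_grad_norm()

print(f"true:     {true_norm:.4f}")
print(f"scaled:   {scaled_norm:.4f}")   # about 1024x the true norm
print(f"unscaled: {unscaled_norm:.4f}")  # matches the true norm
```

With scaled gradients, a max_norm of 1.0 would be compared against a norm roughly 1,024 times too large, so clipping would shrink almost every update.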

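Multiple Models or Loss Functions: Use a single GradScaler for the whole training loop. With several losses, compute each one inside autocast and combine them into one scalar before scaling, as in this multi-task example; with several optimizers, call scaler.step() once per optimizer and scaler.update() once per iteration: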

# Multi-task learning example
scaler = GradScaler()

for inputs, targets_a, targets_b in dataloader:
    inputs = inputs.cuda()
    targets_a, targets_b = targets_a.cuda(), targets_b.cuda()
    optimizer.zero_grad()
    
    with autocast():
        output_a, output_b = model(inputs)
        loss_a = criterion_a(output_a, targets_a)
        loss_b = criterion_b(output_b, targets_b)
        loss = loss_a + 0.5 * loss_b
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
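For the multiple-optimizer case, PyTorch’s AMP examples share one GradScaler across all optimizers. The two-model setup below (an encoder and a head, each with its own optimizer) is a hypothetical illustration; passing enabled=torch.cuda.is_available() lets the same code run without a GPU, where autocast and GradScaler become no-ops and scaler.step() simply calls optimizer.step().

```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# Hypothetical two-model setup: each model has its own optimizer
encoder = nn.Linear(16, 8).to(device)
head = nn.Linear(8, 2).to(device)
opt_encoder = torch.optim.SGD(encoder.parameters(), lr=0.1)
opt_head = torch.optim.Adam(head.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# One scaler shared by both optimizers
scaler = GradScaler(enabled=use_cuda)

inputs = torch.randn(32, 16, device=device)
targets = torch.randint(0, 2, (32,), device=device)
before = encoder.weight.detach().clone()

for step in range(3):
    opt_encoder.zero_grad()
    opt_head.zero_grad()

    with autocast(enabled=use_cuda):
        loss = criterion(head(encoder(inputs)), targets)

    scaler.scale(loss).backward()
    scaler.step(opt_encoder)  # one step() call per optimizer
    scaler.step(opt_head)
    scaler.update()           # one update() per iteration, after all steps

print(f"final loss: {loss.item():.4f}")
```

Calling update() only once keeps a single scale factor consistent across every model whose gradients came from the same scaled loss.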

Selective Autocast Control: Disable autocast for specific operations:

with autocast():
    # Most operations in FP16
    x = model.backbone(inputs)
    
    # Force FP32 for precision-sensitive operation
    with autocast(enabled=False):
        x = x.float()  # Convert to FP32
        x = custom_precision_sensitive_op(x)
    
    # Back to automatic precision
    output = model.head(x)
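The explicit x.float() above is needed because disabling autocast does not convert tensors back: a tensor produced in low precision stays in low precision inside the disabled region. The sketch below checks this, again falling back to CPU autocast with bfloat16 when no GPU is available (an assumption made so it runs anywhere).

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
low_dtype = torch.float16 if device == "cuda" else torch.bfloat16

a = torch.randn(4, 4, device=device)
b = torch.randn(4, 4, device=device)

with torch.autocast(device_type=device, dtype=low_dtype):
    x = a @ b                  # matmul runs in low precision under autocast
    with torch.autocast(device_type=device, enabled=False):
        kept = x + 1.0         # autocast is off, but x is still low precision
        x32 = x.float() @ b    # explicit cast: this matmul runs in FP32

print(x.dtype, kept.dtype, x32.dtype)
```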

Performance Benchmarking and Monitoring

Always measure the actual impact on your specific workload:

import time
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

def benchmark_training(model, dataloader, use_amp=False, num_iterations=100, num_warmup=10):
    model.train()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    # With enabled=False, autocast and GradScaler are no-ops, so one code path covers both modes
    scaler = GradScaler(enabled=use_amp)

    def train_step(inputs, targets):
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # Warmup: excludes cuDNN autotuning and allocator setup from the measurement
    for i, (inputs, targets) in enumerate(dataloader):
        if i >= num_warmup:
            break
        train_step(inputs, targets)

    # Reset peak stats so the reported peak covers only the timed loop
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_time = time.perf_counter()

    for i, (inputs, targets) in enumerate(dataloader):
        if i >= num_iterations:
            break
        train_step(inputs, targets)

    # Wait for queued GPU work to finish before stopping the clock
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time
    peak_mem = torch.cuda.max_memory_allocated()

    print(f"{'AMP' if use_amp else 'FP32'} - Time: {elapsed:.2f}s, "
          f"Peak Memory: {peak_mem / 1024**3:.2f} GB")

# Run benchmarks
benchmark_training(model, train_loader, use_amp=False)
benchmark_training(model, train_loader, use_amp=True)

Common Pitfalls and Troubleshooting

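NaN or Inf Losses: Occasional inf or NaN gradients are expected with AMP; GradScaler detects them, skips that optimizer step, and lowers the scale. A loss that is itself NaN or inf usually means an operation in the forward pass produced values outside the FP16 range. This utility helps locate the problem: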

def check_gradients(model, scaler):
    """Utility to detect gradient issues"""
    scale = scaler.get_scale()
    print(f"Current loss scale: {scale}")
    
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad
            if torch.isnan(grad).any():
                print(f"NaN gradient in {name}")
            if torch.isinf(grad).any():
                print(f"Inf gradient in {name}")
            grad_norm = grad.norm().item()
            if grad_norm > 1000:
                print(f"Large gradient in {name}: {grad_norm}")

# Use in training loop: unscale first so the checks see true gradient magnitudes
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
check_gradients(model, scaler)
scaler.step(optimizer)
scaler.update()
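A forward-pass overflow is easy to reproduce: squaring a moderately large activation already exceeds 65,504. The value 300 below is arbitrary; the same computation in FP32 stays finite, which is one reason autocast runs operations such as exp, pow, and softmax in FP32.

```python
import torch

activation = torch.tensor([300.0])

y16 = activation.half() * activation.half()  # 90,000 exceeds the FP16 maximum of 65,504
y32 = activation * activation                # FP32 handles it easily

print(y16.item())  # inf
print(y32.item())  # 90000.0

# torch.isfinite is a cheap way to catch this before it reaches the loss
print(torch.isfinite(y16).all().item())  # False
```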

Monitoring Scale Factor Changes: Each time the scale drops, GradScaler found inf or NaN gradients and skipped that optimizer step. This helper counts those reductions:

class ScaleMonitor:
    def __init__(self):
        self.prev_scale = None
        self.scale_reductions = 0
    
    def update(self, scaler):
        current_scale = scaler.get_scale()
        if self.prev_scale is not None and current_scale < self.prev_scale:
            self.scale_reductions += 1
            print(f"Scale reduced: {self.prev_scale} -> {current_scale} "
                  f"(total reductions: {self.scale_reductions})")
        self.prev_scale = current_scale

monitor = ScaleMonitor()
# In training loop, after scaler.update():
monitor.update(scaler)

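An occasional reduction is normal, especially early in training while the scale is still settling. If you see frequent scale reductions throughout training, your model may have numerical stability issues that need addressing.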

Conclusion and Real-World Impact

Mixed precision training is a production-ready optimization that delivers measurable benefits with minimal implementation effort. A handful of changes to your training loop can often cut training time roughly in half and substantially reduce memory usage, though the exact gains depend on your model, batch size, and GPU.

For production deployments, consider these final points:

  • Always benchmark on your specific hardware and workload
  • Monitor validation metrics closely during initial AMP experiments
  • Keep gradient clipping values consistent with your FP32 baseline
  • Use AMP by default for new projects on compatible hardware

Quick-start checklist:

  1. Add from torch.cuda.amp import autocast, GradScaler
  2. Create scaler = GradScaler() before training loop
  3. Wrap forward pass with with autocast():
  4. Replace loss.backward() with scaler.scale(loss).backward()
  5. Replace optimizer.step() with scaler.step(optimizer)
  6. Add scaler.update() after optimizer step
  7. If using gradient clipping, add scaler.unscale_(optimizer) before clipping

Mixed precision training is no longer experimental—it’s standard practice for efficient deep learning. Implement it today and reclaim those GPU hours.
