How to Implement Multi-GPU Training in PyTorch
Key Insights
- DistributedDataParallel (DDP) is the recommended approach for multi-GPU training in PyTorch, offering better performance and scalability than DataParallel while being more accessible than FSDP for most use cases
- Multi-GPU training requires careful attention to batch size scaling, learning rate adjustment, and proper process synchronization—simply wrapping your model isn’t enough for optimal performance
- The speedup from multi-GPU training is rarely linear; expect 1.7-1.8x speedup with 2 GPUs and 3.2-3.5x with 4 GPUs due to communication overhead and synchronization costs
Introduction to Multi-GPU Training
Training deep learning models on multiple GPUs isn’t just about throwing more hardware at the problem. It becomes a necessity when a model won’t fit in a single GPU’s memory, or when a dataset is large enough that single-GPU training takes too long. A modern transformer model with billions of parameters can easily exceed 24GB of VRAM once gradients and optimizer states are counted, making multi-GPU training essential rather than optional.
The performance benefits are substantial but not magical. With a proper implementation, you can expect near-linear scaling up to 4 GPUs on a single node, though communication overhead means you’ll rarely see a perfect 4x speedup. Once you outgrow a single machine, multi-node training adds inter-node network bandwidth and latency as a further bottleneck.
Multi-GPU training is overkill if your model trains in reasonable time on a single GPU and fits comfortably in memory. Don’t optimize prematurely—but when your training runs take days instead of hours, it’s time to scale horizontally.
PyTorch Multi-GPU Strategies Overview
PyTorch offers three main approaches to multi-GPU training, each with distinct trade-offs:
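DataParallel (DP) is the simplest but least efficient option. It runs a single process that replicates your model onto every GPU on each forward pass and gathers the outputs on GPU 0, which makes GPU 0 a bottleneck. Its threads also contend for Python’s GIL. PyTorch’s documentation recommends DDP over DP even on a single machine.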
DistributedDataParallel (DDP) runs separate processes for each GPU, eliminating the single-process bottleneck. It’s the recommended approach for most multi-GPU scenarios and what we’ll focus on.
Fully Sharded Data Parallel (FSDP) shards model parameters, gradients, and optimizer states across GPUs, enabling training of models too large for any single GPU. Use it when those combined states no longer fit on one device. With full-precision Adam training, that already happens in the low billions of parameters.
Here’s what the basic syntax looks like for each:
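```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# These are alternatives; pick one.

# DataParallel (not recommended): one process, one thread per GPU
model = nn.DataParallel(model.cuda())

# DistributedDataParallel (recommended): one process per GPU.
# Assumes the process group is initialized and local_rank is this process's GPU index.
model = model.to(local_rank)
model = DDP(model, device_ids=[local_rank])

# FSDP (for models too large for one GPU); also needs an initialized process group
model = FSDP(model)
```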
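A rough back-of-envelope sketch helps show why DDP alone does not help when a model is too big. It assumes plain fp32 training with Adam: 4 bytes each for weights and gradients, plus 8 bytes for Adam’s two moment buffers, for 16 bytes per parameter. It ignores activations, buffers, and framework overhead. DDP keeps a full copy of this state on every GPU. Full sharding (FSDP’s default strategy) splits it roughly evenly across GPUs. The function name is hypothetical.

```python
# Assumption: fp32 weights + fp32 gradients + Adam's two fp32 moment buffers
BYTES_PER_PARAM = 4 + 4 + 8


def model_state_gb_per_gpu(num_params, num_gpus, sharded):
    """Rough per-GPU memory (decimal GB) for weights, gradients, and Adam state.

    Ignores activations and overhead. sharded=False models DDP (a full replica
    on every GPU); sharded=True models an even split across GPUs (full sharding).
    """
    total_bytes = num_params * BYTES_PER_PARAM
    per_gpu_bytes = total_bytes / num_gpus if sharded else total_bytes
    return per_gpu_bytes / 1e9


# A hypothetical 7B-parameter model on 8 GPUs
print(model_state_gb_per_gpu(7e9, 8, sharded=False))  # DDP: 112.0
print(model_state_gb_per_gpu(7e9, 8, sharded=True))   # fully sharded: 14.0
```

Even before activations, the replicated copy is far beyond a 24GB card, while the sharded share fits. Real numbers depend on precision, optimizer choice, and activation memory.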
Setting Up DistributedDataParallel
DDP requires initializing a process group where each process corresponds to one GPU. Each process needs to know its rank (unique ID) and the world size (total number of processes).
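On a single node the rank doubles as the GPU index, and the setup code in this section relies on that. Across several nodes the two differ: each process has a global rank and a local rank, which is its GPU index on its own machine. The sketch below shows the usual mapping. It assumes every node runs the same number of processes and nodes are numbered 0 to N-1. The functions are illustrative, not PyTorch APIs.

```python
def global_rank(node_rank, local_rank, procs_per_node):
    """Global rank of a process, assuming every node runs procs_per_node processes."""
    return node_rank * procs_per_node + local_rank


def total_world_size(num_nodes, procs_per_node):
    """Total number of processes across all nodes."""
    return num_nodes * procs_per_node


# 2 nodes x 4 GPUs: the second GPU (local rank 1) on the second node (node rank 1)
print(global_rank(node_rank=1, local_rank=1, procs_per_node=4))  # 5
print(total_world_size(num_nodes=2, procs_per_node=4))  # 8
```

On more than one node, a process should pass its local rank, not its global rank, to torch.cuda.set_device.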
First, initialize the distributed environment:
```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    """Initialize the distributed environment."""
    # Rendezvous address for a single-node run; setdefault keeps values torchrun already set
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Set the device for this process (on a single node, rank == GPU index)
    torch.cuda.set_device(rank)

    # Initialize process group
    dist.init_process_group(
        backend='nccl',  # Use 'nccl' for GPU, 'gloo' for CPU
        rank=rank,
        world_size=world_size
    )


def cleanup():
    """Clean up the distributed environment."""
    dist.destroy_process_group()
```
Now wrap your model with DDP:
```python
def create_model(rank):
    """Create and wrap model with DDP."""
    # YourModel stands in for your own nn.Module
    model = YourModel().to(rank)
    # Wrap with DDP
    ddp_model = DDP(model, device_ids=[rank])
    return ddp_model
```
The key difference from single-GPU training is that each process holds its own model replica on its assigned GPU. During the backward pass, DDP all-reduces the gradients, so every replica ends up with the same averaged gradient and takes the same optimizer step.
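Averaging gradients across processes gives the same update as one large batch. The pure-Python sketch below illustrates this. It assumes a scalar linear model with mean-squared-error loss and equal-sized shards per rank; with unequal shards, the average of per-shard means would not match the global mean. Each simulated "rank" computes the mean gradient on its own shard, and averaging those values, as DDP’s all-reduce does, reproduces the full-batch gradient. The data and helper names are made up for illustration.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def ddp_style_grad(w, xs, ys, world_size):
    """Split data into equal strided shards, compute per-shard grads, then average."""
    shard_grads = []
    for rank in range(world_size):
        shard_x = xs[rank::world_size]
        shard_y = ys[rank::world_size]
        shard_grads.append(mse_grad(w, shard_x, shard_y))
    # Mirrors the all-reduce (sum) followed by division by world_size
    return sum(shard_grads) / world_size


# Made-up data: 8 samples split evenly across 4 or 2 "ranks"
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]
w = 1.5
print(mse_grad(w, xs, ys))
print(ddp_style_grad(w, xs, ys, world_size=4))  # same value, up to float rounding
```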
Adapting Your Training Loop
Your training loop needs three main modifications: distributed sampling, proper device placement, and rank-aware checkpointing.
Use DistributedSampler to ensure each process sees different data:
```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def create_dataloader(dataset, rank, world_size, batch_size):
    """Create dataloader with distributed sampling (batch_size is per process)."""
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    # The sampler does the shuffling, so don't also pass shuffle=True here
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True
    )
    return dataloader, sampler
```
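The pure-Python approximation below shows how DistributedSampler splits indices, for intuition only. It assumes drop_last=False. The index list is padded by repeating its start until it divides evenly, and then each rank takes every world_size-th index starting at its own rank. The real sampler shuffles with torch.randperm seeded by seed + epoch; here random.Random stands in for it, so the actual orderings will differ. The seeding also explains why sampler.set_epoch(epoch) matters: without it, the epoch stays at 0 and every epoch replays the same order.

```python
import math
import random


def shard_indices(dataset_len, rank, world_size, epoch=0, shuffle=True, seed=0):
    """Indices one rank would see, approximating DistributedSampler(drop_last=False)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Stand-in for torch.randperm; every rank uses the same seed,
        # so all ranks agree on the permutation before splitting it
        random.Random(seed + epoch).shuffle(indices)
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    # Pad by repeating indices from the start so every rank gets the same count
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Strided split: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total_size:world_size]


for r in range(4):
    print(r, shard_indices(10, r, world_size=4, shuffle=False))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

With 10 samples on 4 ranks, indices 0 and 1 are seen twice per epoch. This padding is what keeps every rank’s batch count equal, so no process waits forever on a gradient all-reduce.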
Here’s a complete training loop:
```python
import torch
import torch.nn.functional as F

# Assumes train_dataset and num_epochs are defined elsewhere in the script,
# along with setup, cleanup, create_model, and create_dataloader from above.


def train(rank, world_size):
    setup(rank, world_size)

    # Create model and move to device
    model = create_model(rank)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Create dataloader (batch_size is per process, i.e. per GPU)
    train_loader, sampler = create_dataloader(
        train_dataset, rank, world_size, batch_size=32
    )

    for epoch in range(num_epochs):
        # Reseed the sampler so each epoch uses a different shuffle
        sampler.set_epoch(epoch)
        model.train()

        for batch_idx, (data, target) in enumerate(train_loader):
            # Move data to this process's GPU
            data, target = data.to(rank), target.to(rank)

            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()  # DDP all-reduces (averages) gradients here
            optimizer.step()

            if rank == 0 and batch_idx % 100 == 0:
                # Loss on rank 0's shard only, not averaged across processes
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

        # Save checkpoint from rank 0 only
        if rank == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'checkpoint_epoch_{epoch}.pt')

    cleanup()
```
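Critical point: after every optimizer step all replicas hold identical weights, so save checkpoints from rank 0 only. Writing from every process is redundant and can cause file conflicts. Save model.module.state_dict() rather than model.state_dict(), so that the keys don’t carry DDP’s "module." prefix and the checkpoint loads directly into an unwrapped model.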
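A checkpoint may instead have been saved from the DDP wrapper itself, using model.state_dict() instead of model.module.state_dict(). In that case every key carries a "module." prefix, and a strict load into an unwrapped model fails on the mismatched keys. Below is a minimal pure-Python fix. It assumes the saved state dict is an ordinary mapping from names to tensors, and the helper name is hypothetical. PyTorch also includes torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which performs the same rename in place.

```python
def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DDP prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Plain integers stand in for tensors here
saved = {"module.fc.weight": 1, "module.fc.bias": 2}
print(strip_ddp_prefix(saved))  # {'fc.weight': 1, 'fc.bias': 2}
```

In a real script you would apply it before loading, for example model.load_state_dict(strip_ddp_prefix(checkpoint['model_state_dict'])).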
Launching Multi-GPU Training
Use torchrun (recommended) or the older, now-deprecated torch.distributed.launch:
```bash
# Using torchrun (PyTorch 1.10+)
torchrun --nproc_per_node=4 train.py

# Using torch.distributed.launch (deprecated)
python -m torch.distributed.launch --nproc_per_node=4 train.py
```
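If you launch with plain python train.py instead of a launcher, spawn one process per GPU yourself:
```python
import torch
import torch.multiprocessing as mp


def main():
    world_size = torch.cuda.device_count()
    # Starts world_size processes, calling train(rank, world_size) with rank = 0..world_size-1
    mp.spawn(
        train,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )


if __name__ == '__main__':
    main()
```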
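When you launch with torchrun, don’t call mp.spawn, because torchrun has already started one process per GPU. Instead, read the rank and world size from the environment variables it sets. On a single node, RANK equals the GPU index:
```python
import os

if __name__ == '__main__':
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    train(rank, world_size)
```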
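torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE, among other variables, for every process it starts. The small helper below reads them and falls back to a single-process run when they are absent, so the same script can also be started with plain python for debugging. The helper name and the fallback defaults are assumptions, not part of PyTorch. On multiple nodes, pass LOCAL_RANK rather than RANK to torch.cuda.set_device.

```python
import os


def read_dist_env(env=None):
    """Return (rank, local_rank, world_size) from torchrun-style environment variables.

    Hypothetical helper: falls back to a single-process setup (0, 0, 1)
    when the variables are missing.
    """
    env = os.environ if env is None else env
    rank = int(env.get("RANK", 0))
    local_rank = int(env.get("LOCAL_RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    return rank, local_rank, world_size


# What the third process on the second of two 4-GPU nodes would see
print(read_dist_env({"RANK": "6", "LOCAL_RANK": "2", "WORLD_SIZE": "8"}))  # (6, 2, 8)
```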
Common Pitfalls and Best Practices
Batch size scaling: the batch_size you pass to the DataLoader is per process, so your effective (global) batch size is batch_size * world_size. With batch_size=32 on 4 GPUs, each optimizer step sees 128 samples. Scale your learning rate accordingly. A common heuristic, the linear scaling rule, multiplies the learning rate by the number of GPUs, though this isn’t always optimal.
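Here is a sketch of the arithmetic behind the linear scaling rule. The effective batch size is the per-GPU batch times the number of GPUs times any gradient accumulation steps. The learning rate is then scaled by the ratio of that effective batch to the batch size the base learning rate was tuned for. The helper and its argument names are hypothetical. Treat the result as a starting point for tuning, not a guarantee; a learning-rate warmup is commonly paired with this rule.

```python
def scaled_hparams(base_lr, base_batch, per_gpu_batch, world_size, accumulation_steps=1):
    """Return (effective_batch, lr) under the linear learning-rate scaling rule."""
    effective_batch = per_gpu_batch * world_size * accumulation_steps
    lr = base_lr * effective_batch / base_batch
    return effective_batch, lr


# lr=0.001 tuned for batch 32 on one GPU; now batch_size=32 per GPU on 4 GPUs
print(scaled_hparams(0.001, 32, 32, 4))  # (128, 0.004)
```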
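Gradient accumulation: combine it with DDP for even larger effective batch sizes. Wrap the intermediate micro-batches in model.no_sync() so that DDP only all-reduces gradients on the step that updates the weights:
```python
import contextlib

# model is the DDP-wrapped model; criterion is your loss function
accumulation_steps = 4
for i, (data, target) in enumerate(train_loader):
    is_update_step = (i + 1) % accumulation_steps == 0
    # no_sync() skips DDP's all-reduce on intermediate micro-batches; gradients still accumulate locally
    ctx = contextlib.nullcontext() if is_update_step else model.no_sync()
    with ctx:
        output = model(data.to(rank))
        # Scale so the accumulated gradient is an average over micro-batches
        loss = criterion(output, target.to(rank)) / accumulation_steps
        loss.backward()
    if is_update_step:
        optimizer.step()
        optimizer.zero_grad()
```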
Debugging: Set TORCH_DISTRIBUTED_DEBUG=DETAIL to catch synchronization issues. Use torch.distributed.barrier() to ensure all processes reach the same point.
Monitor GPU utilization to catch imbalances:
```python
import pynvml  # provided by the nvidia-ml-py package


def monitor_gpus():
    """Print memory use and utilization for every visible GPU."""
    pynvml.nvmlInit()
    try:
        for i in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
            util = pynvml.nvmlDeviceGetUtilizationRates(handle)
            print(f"GPU {i}: {mem.used / 1024**2:.0f} MiB used, {util.gpu}% utilized")
    finally:
        pynvml.nvmlShutdown()
```
Performance Benchmarking
Measure actual speedup to validate your implementation:
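```python
import time
from contextlib import contextmanager

import torch
import torch.distributed as dist


@contextmanager
def timer(name):
    """Time a block and report it from rank 0."""
    # CUDA kernels run asynchronously, so wait for queued GPU work before reading the clock
    torch.cuda.synchronize()
    start = time.perf_counter()
    yield
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    if dist.get_rank() == 0:
        print(f"{name}: {elapsed:.2f} seconds")


# In your training loop
with timer(f"Epoch {epoch}"):
    for batch in train_loader:
        ...  # training code
```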
Calculate throughput:
```python
def calculate_throughput(total_samples, elapsed_time, world_size):
    """Samples processed per second across all GPUs.

    total_samples must be the global count (summed over all processes),
    not the number of samples seen by a single rank.
    """
    throughput = total_samples / elapsed_time
    print(f"Throughput: {throughput:.2f} samples/sec")
    print(f"Per-GPU throughput: {throughput / world_size:.2f} samples/sec")
    return throughput
```
Expect 1.7-1.8x speedup with 2 GPUs, 3.2-3.5x with 4 GPUs on a single node. If you’re seeing significantly less, check for CPU bottlenecks in data loading (increase num_workers), small batch sizes (increase to amortize communication overhead), or model architecture issues (excessive small operations that don’t parallelize well).
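To compare your own numbers against those ranges, time the same workload (for example, one epoch) on 1 GPU and on N GPUs. Speedup is the single-GPU time divided by the N-GPU time, and scaling efficiency is speedup divided by N. By this measure, 3.2-3.5x on 4 GPUs corresponds to 80-88% efficiency. The helper below is a minimal sketch that assumes you already have both timings; its name is hypothetical.

```python
def scaling_report(single_gpu_seconds, multi_gpu_seconds, num_gpus):
    """Return (speedup, efficiency) for the same workload timed on 1 and num_gpus GPUs."""
    speedup = single_gpu_seconds / multi_gpu_seconds
    efficiency = speedup / num_gpus
    return speedup, efficiency


# Hypothetical timings: 600 s per epoch on 1 GPU, 180 s on 4 GPUs
speedup, efficiency = scaling_report(600.0, 180.0, 4)
print(f"{speedup:.2f}x speedup, {efficiency:.0%} efficiency")  # 3.33x speedup, 83% efficiency
```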
Multi-GPU training is powerful but requires thoughtful implementation. Start with DDP, monitor your metrics, and scale your hyperparameters appropriately. The performance gains are real, but only if you avoid the common pitfalls.