How to Save and Load Models in PyTorch
Key Insights
- Always save the state_dict rather than the entire model object: it's more portable, version-safe, and compatible across PyTorch updates
- Include optimizer state, epoch number, and loss in your checkpoints to enable seamless training resumption after interruptions
- Handle device mapping explicitly when loading models to avoid GPU/CPU compatibility issues in production environments
Understanding PyTorch’s Two Saving Approaches
PyTorch offers two fundamental methods for persisting models: saving the entire model object or saving just the state dictionary. The distinction matters significantly for production reliability.
Saving the entire model serializes the complete Python object, including the class definition. This creates a tight coupling between your saved model and the exact code structure that created it. Change your model architecture, refactor your code, or update PyTorch, and you risk breaking compatibility.
The state_dict approach saves only the learned parameters—the weights and biases—as a Python dictionary mapping parameter names to tensors. You maintain the model architecture separately in code, then load the learned parameters into a fresh instance. This separation provides flexibility and robustness.
```python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNet()

# Approach 1: Save entire model (NOT recommended)
torch.save(model, 'model_complete.pth')

# Approach 2: Save state_dict (RECOMMENDED)
torch.save(model.state_dict(), 'model_state.pth')
```
The file extensions .pt and .pth are both conventions—PyTorch doesn’t enforce either. Choose one and stick with it across your project.
Saving Models: The Right Way
For production systems, save more than just model weights. Include everything needed to resume training or reproduce results.
```python
def save_checkpoint(model, optimizer, epoch, loss, filepath):
    """Save a complete training checkpoint."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to {filepath}")

# Usage during training
model = SimpleNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# After training epoch
save_checkpoint(
    model=model,
    optimizer=optimizer,
    epoch=10,
    loss=0.234,
    filepath='checkpoint_epoch_10.pt'
)
```
This checkpoint structure enables you to restart training from any saved point. The optimizer state includes internal buffers such as momentum terms and Adam's running moment estimates, which optimizers like Adam or SGD with momentum need in order to resume correctly rather than restart cold.
For inference-only deployments, a minimal save suffices:
```python
# Inference-only save
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': {
        'input_size': 784,
        'hidden_size': 256,
        'num_classes': 10
    }
}, 'model_inference.pt')
```
Including configuration parameters in the save file documents the model architecture and enables automated reconstruction.
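As a sketch of that automated reconstruction: the article's SimpleNet hard-codes its sizes, so this uses a hypothetical configurable variant whose constructor takes the saved config keys. Because it reuses the parameter names fc1 and fc2, it can load the checkpoint saved above without key mismatches.

```python
import torch
import torch.nn as nn

class ConfigurableNet(nn.Module):
    """Hypothetical SimpleNet variant whose layer sizes come from config."""
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

def rebuild_from_checkpoint(filepath):
    """Reconstruct a model from the 'model_config' stored beside its weights."""
    checkpoint = torch.load(filepath, map_location='cpu')
    model = ConfigurableNet(**checkpoint['model_config'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model
```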
Loading Models: Complete Workflow
Loading requires two steps: instantiate the model architecture, then populate it with saved weights. This two-phase approach is why state_dict saves are superior—your architecture lives in version-controlled code, not serialized pickle files.
```python
def load_model_for_inference(filepath, device='cpu'):
    """Load model for inference."""
    # Step 1: Create model architecture
    model = SimpleNet()

    # Step 2: Load saved weights
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Step 3: Set to evaluation mode
    model.eval()
    model.to(device)
    return model

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_model_for_inference('model_inference.pt', device=device)
```
The map_location parameter is critical: it controls the device onto which tensors are loaded. Without it, PyTorch attempts to restore tensors to their original device, which fails if you trained on GPU but deploy on CPU. Note also that since PyTorch 2.6, torch.load defaults to weights_only=True, which restricts deserialization to tensors and plain containers; pass weights_only=False only for trusted checkpoints that contain arbitrary Python objects.
For resuming training, restore the complete training state:
```python
def load_checkpoint_for_training(model, optimizer, filepath, device='cpu'):
    """Resume training from checkpoint."""
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    model.train()
    model.to(device)
    return model, optimizer, epoch, loss

# Resume training
model = SimpleNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model, optimizer, start_epoch, last_loss = load_checkpoint_for_training(
    model, optimizer, 'checkpoint_epoch_10.pt', device=device
)

# Continue training from epoch 11
for epoch in range(start_epoch + 1, 100):
    # Training loop
    pass
```
Handling Device Transitions
Moving models between GPU and CPU requires explicit device mapping. Here’s how to handle common scenarios:
```python
# Scenario 1: Trained on GPU, deploy on CPU
checkpoint = torch.load('model_gpu.pt', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

# Scenario 2: Trained on CPU, deploy on GPU
checkpoint = torch.load('model_cpu.pt', map_location=torch.device('cuda:0'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to('cuda:0')

# Scenario 3: Flexible loading with automatic device detection
def smart_load(filepath):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(filepath, map_location=device)
    model = SimpleNet()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    return model, device
```
Transfer Learning and Partial Loading
Transfer learning often requires loading only specific layers. PyTorch’s state_dict makes this straightforward:
```python
def load_pretrained_backbone(model, pretrained_path, freeze=True):
    """Load pretrained weights for specific layers."""
    pretrained_dict = torch.load(pretrained_path)
    model_dict = model.state_dict()

    # Filter out layers you don't want to load (e.g., final classifier)
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and 'fc2' not in k
    }

    # Update current model dict
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Optionally freeze loaded layers
    if freeze:
        for name, param in model.named_parameters():
            if 'fc2' not in name:
                param.requires_grad = False
    return model
```
This pattern is essential for fine-tuning pretrained models on new tasks.
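After freezing, hand the optimizer only the parameters that still require gradients. A self-contained sketch, using a hypothetical stand-in model with the same layer names as SimpleNet:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in with the same layer names as SimpleNet."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)  # pretend this came from a pretrained backbone
        self.fc2 = nn.Linear(256, 10)   # fresh classifier head

model = TinyNet()

# Freeze everything except the head, as load_pretrained_backbone does
for name, param in model.named_parameters():
    if 'fc2' not in name:
        param.requires_grad = False

# Hand the optimizer only the parameters that still require gradients
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)
```

Filtering the parameter list this way also keeps frozen weights out of the optimizer's state, so Adam allocates moment buffers only for the head.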
Common Pitfalls and Error Handling
Model loading fails for predictable reasons. Build defensive loading functions:
```python
def safe_load_model(model_class, filepath, device='cpu', strict=True):
    """Robust model loading with error handling."""
    try:
        model = model_class()
        checkpoint = torch.load(filepath, map_location=device)

        # Handle both direct state_dict and checkpoint formats
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict, strict=strict)
        model.to(device)
        model.eval()
        return model
    except FileNotFoundError:
        print(f"Model file not found: {filepath}")
        raise
    except RuntimeError as e:
        print(f"Error loading state_dict: {e}")
        print("This often means architecture mismatch between saved and current model")
        raise
    except Exception as e:
        print(f"Unexpected error loading model: {e}")
        raise
```
The strict=False parameter allows loading when there are missing or unexpected keys—useful when architectures evolve but you want to preserve compatible weights.
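With strict=False, load_state_dict also returns the mismatches, so you can log exactly what was skipped. A self-contained sketch with a hypothetical SimpleNetV2 that adds a layer to the original architecture:

```python
import torch
import torch.nn as nn

class SimpleNetV2(nn.Module):
    """Hypothetical evolution of SimpleNet with an extra hidden layer."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc_mid = nn.Linear(256, 256)  # new layer, absent from old checkpoints
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc_mid(x))
        return self.fc2(x)

# Mimic a state_dict saved from the original SimpleNet (no fc_mid keys)
old_state = {
    'fc1.weight': torch.randn(256, 784), 'fc1.bias': torch.randn(256),
    'fc2.weight': torch.randn(10, 256), 'fc2.bias': torch.randn(10),
}

model = SimpleNetV2()
result = model.load_state_dict(old_state, strict=False)
print(result.missing_keys)     # fc_mid.weight and fc_mid.bias were not in the file
print(result.unexpected_keys)  # empty: every saved key found a home
```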
Production Considerations
In production, add versioning and metadata to your saves:
```python
import datetime
import hashlib
import os

def save_versioned_model(model, metadata, save_dir='models'):
    """Save model with version tracking and metadata."""
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    # Create content hash for integrity checking
    state_dict_str = str(model.state_dict())
    content_hash = hashlib.md5(state_dict_str.encode()).hexdigest()[:8]

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'metadata': {
            'timestamp': timestamp,
            'content_hash': content_hash,
            'pytorch_version': torch.__version__,
            **metadata
        }
    }
    filename = f"model_{timestamp}_{content_hash}.pt"
    filepath = os.path.join(save_dir, filename)
    torch.save(checkpoint, filepath)
    return filepath

# Usage
filepath = save_versioned_model(
    model,
    metadata={
        'accuracy': 0.94,
        'dataset': 'MNIST',
        'training_epochs': 50
    }
)
```
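The content hash can then be re-checked at load time. verify_checkpoint is a hypothetical helper; note that hashing str(state_dict()) is only a lightweight sanity check, since the printed representation truncates large tensors.

```python
import hashlib
import torch

def verify_checkpoint(filepath):
    """Hypothetical integrity check for files written by save_versioned_model."""
    checkpoint = torch.load(filepath, map_location='cpu')

    # Recompute the hash the same way save_versioned_model did
    state_str = str(checkpoint['model_state_dict'])
    actual_hash = hashlib.md5(state_str.encode()).hexdigest()[:8]
    return actual_hash == checkpoint['metadata']['content_hash']
```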
For cloud deployment, integrate with object storage:
```python
def save_to_s3(model, bucket_name, key):
    """Save model directly to S3."""
    import boto3
    import io

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)

    s3 = boto3.client('s3')
    s3.upload_fileobj(buffer, bucket_name, key)
    print(f"Model saved to s3://{bucket_name}/{key}")
```
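The BytesIO round-trip at the heart of save_to_s3 can be exercised locally with no cloud credentials; the same pattern works for any object store that accepts a file-like object:

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(8, 2)

# Serialize into an in-memory buffer (what upload_fileobj would send)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Deserialize from the same buffer (what a download would hand back)
restored = nn.Linear(8, 2)
restored.load_state_dict(torch.load(buffer, map_location='cpu'))
```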
Model persistence in PyTorch is straightforward when you follow the state_dict pattern. Save complete checkpoints during training, handle device mapping explicitly, and add versioning for production deployments. These practices ensure your models remain portable, debuggable, and production-ready.