How to Save and Load Models in PyTorch

PyTorch offers two fundamental methods for persisting models: saving the entire model object or saving just the state dictionary. The distinction matters significantly for production reliability.

Key Insights

  • Prefer saving the state_dict over the entire model object: it contains only tensors, so it doesn't depend on your class definitions or module layout and holds up far better across refactors and PyTorch upgrades
  • Include optimizer state, epoch number, and loss in your checkpoints to enable seamless training resumption after interruptions
  • Handle device mapping explicitly when loading models to avoid GPU/CPU compatibility issues in production environments

Understanding PyTorch’s Two Saving Approaches

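Saving the entire model pickles the complete Python object. Pickle does not store the class code itself. It records the module path and class name and looks them up again at load time. This tightly couples the saved file to the exact code layout that created it. If you rename the class, move it to another module, or refactor the surrounding code, loading can break. Recent PyTorch releases add another hurdle. Since version 2.6, torch.load defaults to weights_only=True, which refuses to unpickle arbitrary objects. Whole-model files therefore need an explicit weights_only=False, which is safe only for files you trust.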

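The state_dict approach saves only the model's tensors. That means the learnable parameters (weights and biases) plus persistent buffers such as BatchNorm running statistics, stored in an ordered dictionary that maps names like fc1.weight to tensors. You keep the architecture in code and load the saved tensors into a freshly constructed instance. Because the file holds nothing but tensors and plain containers, it stays loadable as your code evolves, as long as parameter names and shapes still match.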

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNet()

# Approach 1: Save entire model (NOT recommended)
torch.save(model, 'model_complete.pth')

# Approach 2: Save state_dict (RECOMMENDED)
torch.save(model.state_dict(), 'model_state.pth')
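To see what each approach actually writes, here is a standalone sketch. It re-declares SimpleNet, prints the state_dict entries, and reloads both files from a temporary directory. It assumes PyTorch 1.13 or later, where torch.load accepts weights_only. The whole-model file needs weights_only=False, which should only be passed for files you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


model = SimpleNet()

# A state_dict is an ordered mapping from parameter/buffer names to tensors
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))

with tempfile.TemporaryDirectory() as tmp:
    full_path = os.path.join(tmp, "model_complete.pth")
    state_path = os.path.join(tmp, "model_state.pth")
    torch.save(model, full_path)
    torch.save(model.state_dict(), state_path)

    # Whole-model files are arbitrary pickles: SimpleNet must be importable
    # under the same module path, and the restricted unpickler must be disabled
    restored_full = torch.load(full_path, weights_only=False)

    # state_dict files load fine under the restricted weights_only unpickler
    restored = SimpleNet()
    restored.load_state_dict(torch.load(state_path, weights_only=True))

print(type(restored_full).__name__)
```

The four keys printed (fc1.weight, fc1.bias, fc2.weight, fc2.bias) are exactly what load_state_dict expects a fresh SimpleNet to provide.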

The file extensions .pt and .pth are both conventions—PyTorch doesn’t enforce either. Choose one and stick with it across your project.

Saving Models: The Right Way

For production systems, save more than just model weights. Include everything needed to resume training or reproduce results.

def save_checkpoint(model, optimizer, epoch, loss, filepath):
    """Save a complete training checkpoint."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to {filepath}")

# Usage during training
model = SimpleNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# After training epoch
save_checkpoint(
    model=model,
    optimizer=optimizer,
    epoch=10,
    loss=0.234,
    filepath='checkpoint_epoch_10.pt'
)

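This checkpoint structure lets you restart training from any saved point. The optimizer's state_dict holds per-parameter buffers, such as the momentum buffer for SGD with momentum or the running first- and second-moment estimates and step counts for Adam. It also holds hyperparameters such as the learning rate in its param_groups. Without it, a resumed Adam run would restart its moment estimates from zero. It does not contain gradients or learning-rate scheduler state. If you use a scheduler, save its state_dict() in the checkpoint as well.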
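The sketch below makes that concrete. After one Adam step, the optimizer's state_dict holds per-parameter moment buffers. A learning-rate scheduler keeps its own state_dict that has to be saved alongside it. StepLR is used purely for illustration. The 'scheduler_state_dict' key is an assumed name that simply extends the checkpoint layout above.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# StepLR is an illustrative choice; other torch.optim.lr_scheduler classes work the same way
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# One training step so Adam allocates its per-parameter state
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
scheduler.step()

opt_state = optimizer.state_dict()
# 'state' is keyed by parameter index; Adam keeps step, exp_avg and exp_avg_sq
print(sorted(opt_state["state"][0].keys()))
# 'param_groups' holds hyperparameters such as lr and betas
print(opt_state["param_groups"][0]["lr"])

checkpoint = {
    "epoch": 1,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": opt_state,
    # The scheduler is not part of the optimizer's state, so save it separately
    "scheduler_state_dict": scheduler.state_dict(),
    "loss": loss.item(),
}
```

When resuming, call scheduler.load_state_dict(checkpoint['scheduler_state_dict']) after restoring the optimizer.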

For inference-only deployments, a minimal save suffices:

# Inference-only save
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': {
        'input_size': 784,
        'hidden_size': 256,
        'num_classes': 10
    }
}, 'model_inference.pt')

Including the configuration in the save file documents the architecture and enables automated reconstruction, provided the model's constructor accepts those values. SimpleNet, as written, hardcodes its layer sizes.
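To actually rebuild from the stored model_config, the class must take its sizes as arguments. ConfigurableNet below is a hypothetical SimpleNet variant written for this sketch. torch.load returns the config as a plain dict, which can be unpacked straight into the constructor. An in-memory buffer stands in for the file.

```python
import io

import torch
import torch.nn as nn


class ConfigurableNet(nn.Module):
    """Hypothetical SimpleNet variant whose sizes come from constructor arguments."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def load_from_config(source, device="cpu"):
    """Rebuild the architecture from the stored config, then load the weights."""
    checkpoint = torch.load(source, map_location=device, weights_only=True)
    model = ConfigurableNet(**checkpoint["model_config"])
    model.load_state_dict(checkpoint["model_state_dict"])
    return model.to(device).eval()


# Round trip through an in-memory buffer instead of a file on disk
config = {"input_size": 784, "hidden_size": 256, "num_classes": 10}
original = ConfigurableNet(**config)
buffer = io.BytesIO()
torch.save({"model_state_dict": original.state_dict(), "model_config": config}, buffer)
buffer.seek(0)

rebuilt = load_from_config(buffer)
print(rebuilt.fc2.out_features)
```

With the config stored next to the weights, a deployment script needs only the class definition, not hardcoded layer sizes.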

Loading Models: Complete Workflow

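Loading takes two core steps. First, instantiate the model architecture. Then populate it with the saved weights. For inference, finish by calling model.eval() so layers such as dropout and batch normalization switch to inference behavior. Keeping these steps separate is what makes state_dict saves robust: the architecture lives in version-controlled code, while the file carries only tensors.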

def load_model_for_inference(filepath, device='cpu'):
    """Load model for inference."""
    # Step 1: Create model architecture
    model = SimpleNet()
    
    # Step 2: Load saved weights
    checkpoint = torch.load(filepath, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Step 3: Set to evaluation mode
    model.eval()
    model.to(device)
    
    return model

# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_model_for_inference('model_inference.pt', device=device)

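The map_location argument is critical. It remaps each saved tensor to the given device as the file is deserialized. Without it, torch.load tries to restore tensors to the device they were saved from. That raises an error if you trained on a GPU but deploy on a CPU-only machine. Passing weights_only=True, the default since PyTorch 2.6, also restricts unpickling to tensors and plain Python containers. This limits what a malicious file can execute during loading.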

For resuming training, restore the complete training state:

def load_checkpoint_for_training(model, optimizer, filepath, device='cpu'):
    """Resume training from checkpoint."""
    # Move the model first: optimizer.load_state_dict casts the saved optimizer
    # state to the device of the parameters it belongs to
    model.to(device)
    checkpoint = torch.load(filepath, map_location=device, weights_only=True)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
    model.train()
    
    return model, optimizer, epoch, loss

# Resume training
model = SimpleNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model, optimizer, start_epoch, last_loss = load_checkpoint_for_training(
    model, optimizer, 'checkpoint_epoch_10.pt', device=device
)

# Continue training from epoch 11
for epoch in range(start_epoch + 1, 100):
    # Training loop
    pass
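One inexpensive check for a resume path works like this. Take one step, round-trip the checkpoint into a fresh model and optimizer, then confirm that the next step produces identical weights in both. The tiny nn.Linear model and fixed random batch in this sketch are illustrative. A BytesIO buffer stands in for the checkpoint file.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
x, y = torch.randn(16, 4), torch.randn(16, 1)


def train_step(model, optimizer):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


# Original run: one step, then checkpoint
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
train_step(model, optimizer)

buffer = io.BytesIO()
torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()}, buffer)
buffer.seek(0)

# Resumed run: fresh objects populated from the checkpoint
resumed = nn.Linear(4, 1)
resumed_opt = torch.optim.Adam(resumed.parameters(), lr=0.01)
checkpoint = torch.load(buffer, weights_only=True)
resumed.load_state_dict(checkpoint["model_state_dict"])
resumed_opt.load_state_dict(checkpoint["optimizer_state_dict"])

# The next step matches exactly only if Adam's moment estimates were restored
train_step(model, optimizer)
train_step(resumed, resumed_opt)
print(torch.equal(model.weight, resumed.weight))
```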

Handling Device Transitions

Moving models between GPU and CPU requires explicit device mapping. Here’s how to handle common scenarios:

# Scenario 1: Trained on GPU, deploy on CPU
checkpoint = torch.load('model_gpu.pt', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

# Scenario 2: Trained on CPU, deploy on GPU
checkpoint = torch.load('model_cpu.pt', map_location=torch.device('cuda:0'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to('cuda:0')

# Scenario 3: Flexible loading with automatic device detection
def smart_load(filepath):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(filepath, map_location=device)
    model = SimpleNet()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    return model, device
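map_location accepts more than a single device. According to the torch.load documentation, it can also be a string, a dict that remaps saved locations, or a callable that takes (storage, location). The sketch below runs on a CPU-only machine. It saves a small state_dict to memory, reloads it three ways, and checks where the tensors land.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(3, 2)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)


def reload(map_location):
    """Reload the saved state_dict from the in-memory buffer."""
    buffer.seek(0)
    return torch.load(buffer, map_location=map_location, weights_only=True)


# 1. A device string (a torch.device object works the same way)
as_string = reload("cpu")

# 2. A dict remapping saved locations, e.g. GPU 1 -> GPU 0; locations
#    not listed in the dict (here "cpu") are left unchanged
as_dict = reload({"cuda:1": "cuda:0"})

# 3. A callable receiving each storage and its saved location;
#    returning the storage unchanged keeps it on the CPU
as_callable = reload(lambda storage, loc: storage)

for result in (as_string, as_dict, as_callable):
    print({name: t.device.type for name, t in result.items()})
```

On a GPU machine, calling reload("cuda:0") would place every tensor on the first GPU instead.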

Transfer Learning and Partial Loading

Transfer learning often requires loading only specific layers. PyTorch’s state_dict makes this straightforward:

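def load_pretrained_backbone(model, pretrained_path, freeze=True, device='cpu'):
    """Load pretrained weights for every layer except the final classifier (fc2)."""
    pretrained_dict = torch.load(pretrained_path, map_location=device, weights_only=True)
    # Accept either a raw state_dict or a checkpoint dict like the ones above
    pretrained_dict = pretrained_dict.get('model_state_dict', pretrained_dict)
    model_dict = model.state_dict()
    
    # Keep only keys that exist in this model with matching shapes,
    # and skip the final classifier, which is retrained for the new task
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape and 'fc2' not in k
    }
    
    # Update current model dict
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    
    # Optionally freeze only the layers that were actually loaded
    if freeze:
        for name, param in model.named_parameters():
            if name in pretrained_dict:
                param.requires_grad = False
    
    return model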

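This pattern is the usual starting point for fine-tuning a pretrained model on a new task. You reuse the learned feature layers and train a freshly initialized output layer.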
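A common companion step is to give the optimizer only the parameters that still require gradients. Built-in optimizers already skip parameters whose .grad is None. Filtering mainly keeps the optimizer's state limited to the layers being trained and makes the intent explicit. The sketch re-declares SimpleNet, freezes fc1, and confirms that one training step leaves the frozen layer untouched. The random batch is illustrative.

```python
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = SimpleNet()

# Freeze everything except the classifier head
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("fc2.")

# Pass only trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

fc1_before = model.fc1.weight.detach().clone()
fc2_before = model.fc2.weight.detach().clone()

# One training step on a random batch
logits = model(torch.randn(8, 784))
loss = nn.functional.cross_entropy(logits, torch.randint(0, 10, (8,)))
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(len(trainable))
```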

Common Pitfalls and Error Handling

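Model loading tends to fail for a few predictable reasons. The file may be missing. The architecture may not match, with missing, unexpected, or differently shaped keys. Or the file may be in a format you didn't expect. A defensive loading function can handle these cases explicitly: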

def safe_load_model(model_class, filepath, device='cpu', strict=True):
    """Robust model loading with error handling."""
    try:
        model = model_class()
        checkpoint = torch.load(filepath, map_location=device, weights_only=True)
        
        # Handle both direct state_dict and checkpoint formats
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        
        model.load_state_dict(state_dict, strict=strict)
        model.to(device)
        model.eval()
        
        return model
    
    except FileNotFoundError:
        print(f"Model file not found: {filepath}")
        raise
    except RuntimeError as e:
        print(f"Error loading state_dict: {e}")
        print("This often means architecture mismatch between saved and current model")
        raise
    except Exception as e:
        print(f"Unexpected error loading model: {e}")
        raise

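The strict=False parameter tolerates missing or unexpected keys. This is useful when an architecture has gained or lost layers but you want to keep the weights that still fit. It does not tolerate shape mismatches. A parameter present in both models but with a different size still raises a RuntimeError.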
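Note that load_state_dict(strict=False) does not fail silently. It returns a named tuple with missing_keys and unexpected_keys, which is worth logging. The sketch below uses three hypothetical stand-ins. OldNet plays the saved architecture, NewNet adds a layer, and NarrowNet changes a layer's width. Only the width change raises.

```python
import torch
import torch.nn as nn


class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)


class NewNet(OldNet):
    """Hypothetical evolved architecture: same layers plus a new fc3."""

    def __init__(self):
        super().__init__()
        self.fc3 = nn.Linear(10, 10)


class NarrowNet(nn.Module):
    """Hypothetical architecture whose fc1 width no longer matches the checkpoint."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)


old_state = OldNet().state_dict()

# Added layers show up as missing keys instead of an error
result = NewNet().load_state_dict(old_state, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# Shape mismatches still raise, even with strict=False
try:
    NarrowNet().load_state_dict(old_state, strict=False)
    shape_error = None
except RuntimeError as err:
    shape_error = err
print("shape mismatch raised:", shape_error is not None)
```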

Production Considerations

In production, add versioning and metadata to your saves:

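import datetime
import hashlib
import os

def save_versioned_model(model, metadata, save_dir='models'):
    """Save model with version tracking and metadata."""
    os.makedirs(save_dir, exist_ok=True)
    
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    
    # Hash the raw tensor bytes for integrity checking; str(state_dict) would
    # abbreviate large tensors with '...', so different weights could collide
    hasher = hashlib.sha256()
    for name, tensor in sorted(model.state_dict().items()):
        flat = tensor.detach().cpu().contiguous().reshape(-1)
        hasher.update(name.encode())
        hasher.update(flat.view(torch.uint8).numpy().tobytes())
    content_hash = hasher.hexdigest()[:8]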
    
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'metadata': {
            'timestamp': timestamp,
            'content_hash': content_hash,
            'pytorch_version': torch.__version__,
            **metadata
        }
    }
    
    filename = f"model_{timestamp}_{content_hash}.pt"
    filepath = os.path.join(save_dir, filename)
    
    torch.save(checkpoint, filepath)
    return filepath

# Usage
filepath = save_versioned_model(
    model,
    metadata={
        'accuracy': 0.94,
        'dataset': 'MNIST',
        'training_epochs': 50
    }
)
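A content hash only helps if something checks it when the file is loaded. The sketch below is a self-contained save-and-verify pair that mirrors the metadata['content_hash'] layout. It hashes each tensor's name and raw bytes rather than str(state_dict), because str() abbreviates large tensors. state_dict_digest and load_verified are hypothetical helper names. The code assumes NumPy is installed for the .numpy() byte export.

```python
import hashlib
import io

import torch
import torch.nn as nn


def state_dict_digest(state_dict):
    """SHA-256 over every tensor's name and raw bytes, in sorted key order."""
    hasher = hashlib.sha256()
    for name in sorted(state_dict):
        # Flatten and reinterpret as bytes so any dtype (including bfloat16) hashes
        flat = state_dict[name].detach().cpu().contiguous().reshape(-1)
        hasher.update(name.encode())
        hasher.update(flat.view(torch.uint8).numpy().tobytes())
    return hasher.hexdigest()


def load_verified(source, model):
    """Load weights only if their digest matches the one stored at save time."""
    checkpoint = torch.load(source, map_location="cpu", weights_only=True)
    state_dict = checkpoint["model_state_dict"]
    if state_dict_digest(state_dict) != checkpoint["metadata"]["content_hash"]:
        raise ValueError("content hash mismatch: checkpoint may be corrupted")
    model.load_state_dict(state_dict)
    return model


model = nn.Linear(4, 2)
state = model.state_dict()

good = io.BytesIO()
torch.save({"model_state_dict": state,
            "metadata": {"content_hash": state_dict_digest(state)}}, good)
good.seek(0)
restored = load_verified(good, nn.Linear(4, 2))

# Altering any tensor after hashing changes the digest, so loading refuses it
tampered = {name: t.clone() for name, t in state.items()}
tampered["bias"] += 1.0
bad = io.BytesIO()
torch.save({"model_state_dict": tampered,
            "metadata": {"content_hash": state_dict_digest(state)}}, bad)
bad.seek(0)
try:
    load_verified(bad, nn.Linear(4, 2))
    tamper_detected = False
except ValueError:
    tamper_detected = True
print(tamper_detected)
```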

For cloud deployment, integrate with object storage:

def save_to_s3(model, bucket_name, key):
    """Save model directly to S3."""
    import boto3
    import io
    
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)
    
    s3 = boto3.client('s3')
    s3.upload_fileobj(buffer, bucket_name, key)
    print(f"Model saved to s3://{bucket_name}/{key}")
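The reverse direction uses the same in-memory pattern. This sketch assumes boto3's client method download_fileobj(Bucket, Key, Fileobj), the counterpart of upload_fileobj. It accepts an optional client so it can be exercised without AWS credentials. load_from_s3, FakeS3Client, and the bucket and key names are hypothetical.

```python
import io

import torch
import torch.nn as nn


def load_from_s3(model, bucket_name, key, device="cpu", s3_client=None):
    """Download a state_dict from S3 into memory and load it into `model`."""
    if s3_client is None:
        import boto3  # imported lazily so a stub client can be injected instead
        s3_client = boto3.client("s3")

    buffer = io.BytesIO()
    s3_client.download_fileobj(bucket_name, key, buffer)
    buffer.seek(0)  # rewind so torch.load reads from the start

    state_dict = torch.load(buffer, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(device)


class FakeS3Client:
    """Stub exposing the same download_fileobj call, for use without AWS."""

    def __init__(self, objects):
        self.objects = objects  # {(bucket, key): bytes}

    def download_fileobj(self, bucket, key, fileobj):
        fileobj.write(self.objects[(bucket, key)])


source = nn.Linear(4, 2)
payload = io.BytesIO()
torch.save(source.state_dict(), payload)
fake = FakeS3Client({("my-bucket", "models/linear.pt"): payload.getvalue()})

restored = load_from_s3(nn.Linear(4, 2), "my-bucket", "models/linear.pt", s3_client=fake)
```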

Model persistence in PyTorch is straightforward when you follow the state_dict pattern. Save complete checkpoints during training, handle device mapping explicitly, and add versioning for production deployments. These practices ensure your models remain portable, debuggable, and production-ready.