How to Implement an Autoencoder in PyTorch

Key Insights

  • Autoencoders compress data into a lower-dimensional latent space and reconstruct it, which makes them well suited to dimensionality reduction, anomaly detection, and denoising
  • PyTorch’s modular design lets you build autoencoders as simple nn.Module classes with separate encoder and decoder components, supporting both fully-connected and convolutional architectures
  • Training needs only a reconstruction loss (MSE or BCE) and no labeled data, which makes autoencoders powerful unsupervised learning tools

Introduction to Autoencoders

Autoencoders are neural networks designed to learn efficient data representations in an unsupervised manner. They work by compressing input data into a lower-dimensional latent space through an encoder, then reconstructing the original input from this compressed representation using a decoder. The network learns by minimizing the difference between the input and its reconstruction.

The architecture consists of three main components: the encoder network that maps inputs to the latent space, the latent space (or bottleneck) that holds the compressed representation, and the decoder network that reconstructs the input from the latent representation. This bottleneck forces the network to learn the most salient features of the data.
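Written out, with notation introduced here only for illustration (encoder f_θ mapping an input x in R^d to a code z = f_θ(x) in R^k with k < d, decoder g_φ mapping the code back to R^d, and N training inputs x_i), the MSE objective used in the training section below is the average per-element squared reconstruction error:

```latex
\min_{\theta,\phi}\; \frac{1}{N d}\sum_{i=1}^{N} \big\lVert x_i - g_\phi\big(f_\theta(x_i)\big) \big\rVert_2^2
```

For the fully-connected MNIST model below, d = 784 and k = 32, so each image is squeezed through a code 24.5 times smaller than the input.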

Autoencoders excel at several practical tasks. Use them for dimensionality reduction as a nonlinear alternative to PCA, for anomaly detection by identifying inputs that reconstruct poorly, for denoising by training on corrupted inputs, and as pretraining for downstream supervised tasks. They’re particularly effective when you have abundant unlabeled data but limited labeled examples.

Setting Up the Environment

You’ll need PyTorch, torchvision for dataset utilities, and matplotlib for visualization. Install these packages if you haven’t already:

pip install torch torchvision matplotlib

Here are the necessary imports:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Building the Autoencoder Architecture

Let’s start with a fully-connected autoencoder for MNIST images. The encoder progressively reduces dimensionality through linear layers, and the decoder mirrors this structure to reconstruct the original dimensions.

class FullyConnectedAutoencoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
            nn.ReLU()
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()  # Output in [0, 1] range
        )
    
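    def forward(self, x):
        original_shape = x.shape
        # Flatten (B, 1, 28, 28) images to (B, 784) vectors
        x = x.view(x.size(0), -1)
        # Encode to the latent space
        latent = self.encoder(x)
        # Decode back to a flat (B, 784) vector
        reconstructed = self.decoder(latent)
        # Restore the input shape so the loss and plotting code can compare output and input directly
        return reconstructed.view(original_shape)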
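Because the encoder is a separate submodule, it can be used on its own for dimensionality reduction. The sketch below rebuilds an encoder with the same layer sizes as FullyConnectedAutoencoder.encoder so it runs standalone (with a trained model you would use model.encoder instead); the random batch stands in for real MNIST images, and the helper name encode_batch is hypothetical.

```python
import torch
import torch.nn as nn

# Standalone copy of the FullyConnectedAutoencoder encoder layout (784 -> 256 -> 128 -> 32).
# With a trained model, use model.encoder instead of this untrained copy.
encoder = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 32),
    nn.ReLU(),
)

def encode_batch(encoder: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten a batch of images and return their latent codes."""
    encoder.eval()
    with torch.no_grad():
        flat = images.view(images.size(0), -1)  # (B, 1, 28, 28) -> (B, 784)
        return encoder(flat)

# Random stand-in for a batch of 8 MNIST images
images = torch.rand(8, 1, 28, 28)
codes = encode_batch(encoder, images)
print(codes.shape)  # torch.Size([8, 32])
```

The resulting 32-dimensional codes can then replace raw pixels as input to clustering or visualization tools, which is how an autoencoder serves as a nonlinear counterpart to PCA.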

For image data, convolutional autoencoders often perform better by preserving spatial structure:

class ConvolutionalAutoencoder(nn.Module):
    def __init__(self, latent_dim=64):
        super().__init__()
        
        # Encoder: 28x28 -> 14x14 -> 7x7 -> latent
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 14x14
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 7x7
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, latent_dim),
            nn.ReLU()
        )
        
        # Decoder: latent -> 7x7 -> 14x14 -> 28x28
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 32 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (32, 7, 7)),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),   # 28x28
            nn.Sigmoid()
        )
    
    def forward(self, x):
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed

The convolutional version uses stride-2 convolutions for downsampling and transposed convolutions for upsampling, maintaining spatial relationships throughout the network.
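The size comments in the encoder and decoder (28x28 -> 14x14 -> 7x7 and back) follow from the output-size formulas given in the PyTorch documentation for Conv2d and ConvTranspose2d, specialized here to dilation 1. This sketch evaluates those formulas by hand and cross-checks them against real layers on a dummy tensor; the helper names conv_out and conv_transpose_out are our own.

```python
import torch
import torch.nn as nn

def conv_out(size, kernel, stride, padding):
    # Conv2d output size (dilation=1): floor((in + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding, output_padding):
    # ConvTranspose2d output size (dilation=1): (in - 1) * s - 2p + k + output_padding
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder: 28 -> 14 -> 7
print(conv_out(28, 3, 2, 1), conv_out(14, 3, 2, 1))  # 14 7
# Decoder: 7 -> 14 -> 28
print(conv_transpose_out(7, 3, 2, 1, 1), conv_transpose_out(14, 3, 2, 1, 1))  # 14 28

# Cross-check against real layers with a dummy tensor
x = torch.zeros(1, 1, 28, 28)
down = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)(x)
up = nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)(down)
print(down.shape, up.shape)  # torch.Size([1, 16, 14, 14]) torch.Size([1, 1, 28, 28])
```

The output_padding=1 argument matters: a stride-2 convolution maps both 13 and 14 to 7, so without it the first transposed convolution would produce 13x13 instead of 14x14, and the decoder would end at 25x25 rather than 28x28.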

Preparing the Dataset

MNIST is a good starting point: simple grayscale images with clear structure. ToTensor() converts each image to a float tensor with pixel values scaled to [0, 1], and we create data loaders for batched training:

# Define transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    # MNIST images are already in [0, 1] after ToTensor()
])

# Load datasets
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Create data loaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
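The early stopping suggested at the end of this article needs a validation set that is separate from the test set. One option is torch.utils.data.random_split. In this sketch a small random TensorDataset stands in for train_dataset so it runs without downloading MNIST; holding out one sixth (10,000 of MNIST's 60,000 training images) and the seed 42 are assumptions, and the loader names are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Small random stand-in for train_dataset: 600 fake 28x28 images with dummy labels
dataset = TensorDataset(torch.rand(600, 1, 28, 28), torch.zeros(600, dtype=torch.long))

# Hold out one sixth of the training data for validation
val_size = len(dataset) // 6
train_size = len(dataset) - val_size
train_subset, val_subset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),  # reproducible split
)

train_split_loader = DataLoader(train_subset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=128, shuffle=False)
print(len(train_subset), len(val_subset))  # 500 100
```

With the real data, pass train_dataset in place of the stand-in and keep test_loader for the final evaluation only.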

Training the Autoencoder

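Use Mean Squared Error (MSE) or Binary Cross-Entropy (BCE) as the reconstruction loss. MSE is a safe default for image data scaled to [0, 1]; BCE also works for targets in [0, 1], such as MNIST pixel intensities, provided the decoder output passes through a sigmoid, as it does in both models above. The training loop is straightforward: no labels are needed, because the target is the input itself.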
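To see how the two losses score the same reconstruction, here is a toy comparison on a four-pixel "image" with an uninformative prediction of 0.5 everywhere (the tensors are made up purely for illustration). It also shows nn.BCEWithLogitsLoss, the numerically safer choice if you drop the decoder's final nn.Sigmoid and feed it raw outputs.

```python
import math
import torch
import torch.nn as nn

# Toy "image": two black (0.0) and two white (1.0) pixels
target = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
# Uninformative reconstruction: 0.5 everywhere, as a sigmoid output might give
prediction = torch.full_like(target, 0.5)

mse = nn.MSELoss()(prediction, target)
bce = nn.BCELoss()(prediction, target)  # expects predictions in (0, 1)
print(f"MSE: {mse.item():.4f}")  # MSE: 0.2500
print(f"BCE: {bce.item():.4f}")  # BCE: 0.6931  (= ln 2)

# Numerically safer variant: no final Sigmoid, raw logits into BCEWithLogitsLoss.
logits = torch.zeros_like(target)  # sigmoid(0) = 0.5, the same prediction as above
bce_logits = nn.BCEWithLogitsLoss()(logits, target)
print(f"BCEWithLogits: {bce_logits.item():.4f}")  # BCEWithLogits: 0.6931
print(math.isclose(bce.item(), bce_logits.item(), rel_tol=1e-6))  # True
```

Because the two losses live on different scales, switching between them usually means re-tuning the learning rate.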

# Initialize model
model = ConvolutionalAutoencoder(latent_dim=64).to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 20
train_losses = []

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    
    for batch_idx, (data, _) in enumerate(train_loader):
        # Move data to device
        data = data.to(device)
        
        # Forward pass
        reconstructed = model(data)
        loss = criterion(reconstructed, data)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    # Calculate average loss
    avg_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_loss)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")

# Save the model
torch.save(model.state_dict(), 'autoencoder.pth')

Monitor the loss curve: it should decrease steadily. If it plateaus early, try adjusting the learning rate or the latent dimension size.
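A small plotting helper makes the curve easy to inspect. This is a sketch: the function name plot_loss_curve and the output filename are hypothetical, and the optional val_losses argument is for when you also compute a held-out loss each epoch. A log-scale y-axis makes late-training plateaus easier to see.

```python
import matplotlib.pyplot as plt

def plot_loss_curve(train_losses, val_losses=None, path="loss_curve.png"):
    """Plot per-epoch reconstruction losses and save the figure to `path` (hypothetical helper)."""
    epochs = range(1, len(train_losses) + 1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(epochs, train_losses, marker="o", label="train")
    if val_losses is not None:
        ax.plot(epochs, val_losses, marker="o", label="validation")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Reconstruction loss (MSE)")
    ax.set_yscale("log")  # losses are positive; log scale exposes small late improvements
    ax.legend()
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)  # free the figure when called repeatedly
    return path
```

After the training loop above, call plot_loss_curve(train_losses) to save the curve.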

Evaluating and Visualizing Results

Visualization reveals how well your autoencoder captures data structure. Compare original images with reconstructions to assess quality:

def visualize_reconstructions(model, data_loader, num_images=10):
    model.eval()
    
    # Get a batch of test images
    data_iter = iter(data_loader)
    images, _ = next(data_iter)
    images = images[:num_images].to(device)
    
    # Generate reconstructions
    with torch.no_grad():
        reconstructed = model(images)
    
    # Move to CPU for plotting
    images = images.cpu()
    reconstructed = reconstructed.cpu()
    
    # Plot
    fig, axes = plt.subplots(2, num_images, figsize=(15, 3))
    
    for i in range(num_images):
        # Original images
        axes[0, i].imshow(images[i].squeeze(), cmap='gray')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=12)
        
        # Reconstructed images
        axes[1, i].imshow(reconstructed[i].squeeze(), cmap='gray')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Reconstructed', fontsize=12)
    
    plt.tight_layout()
    plt.savefig('reconstructions.png', dpi=150, bbox_inches='tight')
    plt.show()

# Visualize results
visualize_reconstructions(model, test_loader)

Calculate reconstruction error to quantify performance:

def evaluate_reconstruction_error(model, data_loader):
    model.eval()
    total_loss = 0.0
    num_samples = 0
    criterion = nn.MSELoss()
    
    with torch.no_grad():
        for data, _ in data_loader:
            data = data.to(device)
            reconstructed = model(data)
            loss = criterion(reconstructed, data)
            # Weight each batch by its size so a smaller final batch doesn't skew the mean
            total_loss += loss.item() * data.size(0)
            num_samples += data.size(0)
    
    avg_loss = total_loss / num_samples
    print(f"Average reconstruction error (per-pixel MSE): {avg_loss:.6f}")
    return avg_loss

evaluate_reconstruction_error(model, test_loader)
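The introduction mentions anomaly detection by spotting inputs that reconstruct poorly. One common recipe, sketched here under assumptions, is to compute a per-sample error instead of a dataset average, derive a threshold from errors on data you believe is normal, and flag anything above it. The helper names and the 99th-percentile threshold are hypothetical choices you would tune for your data; the code assumes the model returns the same shape as its input, as ConvolutionalAutoencoder does.

```python
import torch
import torch.nn as nn

def per_sample_errors(model: nn.Module, data_loader, device="cpu") -> torch.Tensor:
    """Return one mean-squared reconstruction error per input sample."""
    model.eval()
    errors = []
    with torch.no_grad():
        for data, _ in data_loader:
            data = data.to(device)
            reconstructed = model(data)
            # Average over every dimension except the batch dimension
            err = ((reconstructed - data) ** 2).flatten(start_dim=1).mean(dim=1)
            errors.append(err.cpu())
    return torch.cat(errors)

def flag_anomalies(errors: torch.Tensor, reference_errors: torch.Tensor, quantile=0.99):
    """Flag samples whose error exceeds the given quantile of reference (normal) errors."""
    threshold = torch.quantile(reference_errors, quantile)
    return errors > threshold, threshold.item()

# Example usage with the trained model from above:
# reference_errors = per_sample_errors(model, train_loader, device)
# test_errors = per_sample_errors(model, test_loader, device)
# is_anomaly, threshold = flag_anomalies(test_errors, reference_errors)
```

Sorting the test images by error and plotting the highest-error ones is a quick sanity check that the flagged samples really look unusual.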

Practical Extensions

Denoising autoencoders learn robust representations by reconstructing clean images from corrupted inputs. Add Gaussian noise during training:

def add_noise(images, noise_factor=0.3):
    noisy_images = images + noise_factor * torch.randn_like(images)
    noisy_images = torch.clamp(noisy_images, 0., 1.)
    return noisy_images

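# Modified training loop for denoising (reuses model, criterion, optimizer, and
# num_epochs from above; start from a fresh model to train a dedicated denoiser)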
for epoch in range(num_epochs):
    model.train()
    for data, _ in train_loader:
        data = data.to(device)
        
        # Add noise to input
        noisy_data = add_noise(data)
        
        # Reconstruct clean image from noisy input
        reconstructed = model(noisy_data)
        loss = criterion(reconstructed, data)  # Compare to clean data
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
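To check that denoising actually helps, compare the error of the noisy input against the error of the model's output, both measured against the clean images. This sketch repeats the add_noise function from above so it runs standalone; denoising_report is a hypothetical helper name, and it assumes the model's output has the same shape as its input.

```python
import torch
import torch.nn as nn

def add_noise(images, noise_factor=0.3):
    # Same corruption as above: additive Gaussian noise, clipped back to [0, 1]
    noisy = images + noise_factor * torch.randn_like(images)
    return torch.clamp(noisy, 0.0, 1.0)

def denoising_report(model: nn.Module, data_loader, noise_factor=0.3, device="cpu"):
    """Return (MSE of noisy vs clean, MSE of model(noisy) vs clean), per element."""
    model.eval()
    noisy_se, denoised_se, n_elements = 0.0, 0.0, 0
    with torch.no_grad():
        for data, _ in data_loader:
            clean = data.to(device)
            noisy = add_noise(clean, noise_factor)
            denoised = model(noisy)
            noisy_se += ((noisy - clean) ** 2).sum().item()
            denoised_se += ((denoised - clean) ** 2).sum().item()
            n_elements += clean.numel()
    return noisy_se / n_elements, denoised_se / n_elements

# Example usage with the denoising model trained above:
# noisy_mse, denoised_mse = denoising_report(model, test_loader, device=device)
```

If the autoencoder has learned to denoise, the second number should be clearly smaller than the first; an identity "model" would score the same on both.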

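Variational autoencoders (VAEs) make the latent space probabilistic: for each input, the encoder outputs the mean and log-variance of a Gaussian distribution, a latent vector is sampled from it, and a KL-divergence term in the loss keeps these distributions close to a standard normal prior. Decoding vectors drawn from that prior generates new samples. Sparse autoencoders add an L1 penalty on the latent activations to the reconstruction loss, pushing most latent units toward zero, which can make individual units easier to interpret.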
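A minimal sketch of both ideas follows. TinyVAE, vae_loss, and l1_sparsity_penalty are hypothetical names; the layer sizes, latent_dim=16, and the L1 weight of 1e-3 are illustrative assumptions rather than recommended settings. The reparameterization step (z = mu + sigma * eps) is what keeps sampling differentiable, and the KL term is the closed form for a diagonal Gaussian against N(0, I).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVAE(nn.Module):
    """Minimal fully-connected VAE sketch for 28x28 images (hypothetical)."""

    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=16):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # mean of q(z|x)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # log-variance of q(z|x)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps, with eps ~ N(0, I); gradients flow through mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten to (B, 784)
        h = self.hidden(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

def vae_loss(reconstructed, x, mu, logvar):
    # Summed BCE reconstruction term + KL(q(z|x) || N(0, I)), averaged per sample
    recon = F.binary_cross_entropy(reconstructed, x.view(x.size(0), -1), reduction="sum")
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (recon + kl) / x.size(0)

def l1_sparsity_penalty(latent, weight=1e-3):
    # Sparse-autoencoder regularizer: add this to the reconstruction loss
    return weight * latent.abs().mean()

# Generating new images: decode draws from the N(0, I) prior (untrained here, so just noise)
vae = TinyVAE()
vae.eval()
with torch.no_grad():
    samples = vae.decoder(torch.randn(8, 16)).view(8, 1, 28, 28)
print(samples.shape)  # torch.Size([8, 1, 28, 28])
```

Training TinyVAE uses the same loop as above, with recon, mu, logvar = vae(data) and vae_loss in place of the MSE criterion.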

For production use, consider these optimizations: implement early stopping based on validation loss, use learning rate scheduling to fine-tune convergence, experiment with different latent dimensions to balance compression and reconstruction quality, and add batch normalization for training stability with deeper networks.
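Early stopping and learning rate scheduling can be combined in a single training function. This is one possible sketch, not a prescribed setup: it halves the learning rate with torch.optim.lr_scheduler.ReduceLROnPlateau when validation loss stalls for 2 epochs, stops after 5 epochs without improvement, and restores the best weights. The function name and all of those numbers are assumptions; val_loader should come from a held-out split of the training data rather than the test set.

```python
import copy
import torch
import torch.nn as nn
import torch.optim as optim

def train_with_early_stopping(model, train_loader, val_loader, max_epochs=50,
                              patience=5, lr=1e-3, device="cpu"):
    """Train an autoencoder with LR reduction on plateau and early stopping."""
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Halve the learning rate if validation loss has not improved for 2 epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

    best_val, best_state, epochs_without_improvement = float("inf"), None, 0
    history = []
    for epoch in range(max_epochs):
        model.train()
        for data, _ in train_loader:
            data = data.to(device)
            loss = criterion(model(data), data)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation pass: sample-weighted mean of the per-batch MSE
        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for data, _ in val_loader:
                data = data.to(device)
                total += criterion(model(data), data).item() * data.size(0)
                count += data.size(0)
        val_loss = total / count
        history.append(val_loss)
        scheduler.step(val_loss)

        if val_loss < best_val:
            best_val, epochs_without_improvement = val_loss, 0
            best_state = copy.deepcopy(model.state_dict())  # snapshot the best weights
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break

    model.load_state_dict(best_state)  # restore the best weights seen
    return history
```

Keeping the scheduler's patience shorter than the early-stopping patience gives the model a chance to improve at a lower learning rate before training is cut off.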

Autoencoders provide a powerful foundation for unsupervised learning. Start with these implementations, then adapt the architecture and training procedure to your specific data and use case.
