How to Implement a VAE in PyTorch

Key Insights

  • VAEs learn a probabilistic latent representation by encoding inputs as distributions rather than fixed vectors, enabling controlled generation of new samples
  • The reparameterization trick (z = μ + σ ⊙ ε) is essential for backpropagation through stochastic sampling operations
  • VAE loss balances reconstruction accuracy (BCE/MSE) with latent space regularization (KL divergence), creating a structured generative model

Introduction to Variational Autoencoders

Variational Autoencoders (VAEs) are generative models that learn to encode data into a probabilistic latent space. Unlike standard autoencoders, which map each input to a single fixed vector, VAEs encode each input as a probability distribution, specifically the mean and variance of a diagonal Gaussian.

The key difference is fundamental: a standard autoencoder learns z = encoder(x), while a VAE learns μ, σ = encoder(x) and then samples z ~ N(μ, σ²). Because training also pulls these distributions toward a standard normal prior, a trained VAE can generate new samples by drawing z from N(0, I) and passing it through the decoder.

The VAE architecture consists of three components:

# Conceptual flow of a VAE
x (input)
   → Encoder → (μ, log_σ²)          # Encode to distribution parameters
   → Reparameterization → z         # Sample latent vector
   → Decoder → x̂ (reconstruction)

# Loss = Reconstruction Loss + KL Divergence

The reparameterization trick is what makes VAEs trainable. Sampling z directly from N(μ, σ²) is a random operation that gradients cannot pass through. Instead, we sample ε ~ N(0, I) and compute z = μ + σ ⊙ ε. The randomness now lives entirely in ε, which does not depend on the network, so gradients flow through μ and σ.

Setting Up the Encoder Network

The encoder maps input data to the parameters of a latent distribution. For image data, we typically use convolutional layers; for tabular data, fully connected layers work well.

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(Encoder, self).__init__()
        
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
    def forward(self, x):
        h = torch.relu(self.fc1(x))
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        return mu, log_var

For convolutional architectures working with images:

class ConvEncoder(nn.Module):
    def __init__(self, latent_dim=20):
        super(ConvEncoder, self).__init__()
        
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten()  # 64 * 7 * 7 = 3136
        )
        
        self.fc_mu = nn.Linear(3136, latent_dim)
        self.fc_logvar = nn.Linear(3136, latent_dim)
        
    def forward(self, x):
        h = self.conv_layers(x)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        return mu, log_var

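Note that the encoder outputs log_var (the log of the variance) rather than the variance itself. A linear layer can produce any real number, while a variance must be positive; predicting log σ² removes that constraint, because exponentiating it always yields a valid variance. It also lets the network express very small and very large variances on a manageable numeric scale.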
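
To make this concrete, here is a minimal sketch that treats a handful of arbitrary real numbers as log-variance outputs. The tensor `raw` is an illustrative stand-in for what a linear layer might produce; it is not part of the model above.

```python
import torch

# Stand-in for raw linear-layer outputs: any real value is possible.
raw = torch.tensor([-20.0, -1.0, 0.0, 1.0, 20.0])

# Reading the output as log(sigma^2) gives a strictly positive std
# without clamping or an extra activation function.
std = torch.exp(0.5 * raw)

print(std)                    # values span several orders of magnitude
print(bool((std > 0).all()))  # every entry is a valid standard deviation
```

An alternative some implementations use is to pass the raw output through softplus to get a positive standard deviation directly; the log-variance form is kept here because it also simplifies the KL term.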

Implementing the Reparameterization Trick

The reparameterization trick is the mathematical core that makes VAEs trainable with standard backpropagation.

def reparameterize(mu, log_var):
    """
    Reparameterization trick: z = μ + σ * ε, where ε ~ N(0,1)
    
    Args:
        mu: Mean of the latent distribution (batch_size, latent_dim)
        log_var: Log variance of the latent distribution (batch_size, latent_dim)
    
    Returns:
        Sampled latent vector z (batch_size, latent_dim)
    """
    std = torch.exp(0.5 * log_var)  # Convert log_var to std
    eps = torch.randn_like(std)      # Sample ε from N(0,1)
    z = mu + eps * std               # Reparameterization
    return z

During training, this function samples from the learned distribution. During inference, you can either sample for generation or use mu directly for deterministic encoding.
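
As a quick sanity check, the sketch below confirms that gradients actually reach both mu and log_var after sampling. The function is repeated so the snippet runs on its own, and the zero-valued tensors are illustrative stand-ins for encoder outputs.

```python
import torch

def reparameterize(mu, log_var):
    # Same computation as above: z = mu + sigma * eps
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std

torch.manual_seed(0)

# Leaf tensors standing in for encoder outputs (illustrative values)
mu = torch.zeros(4, 2, requires_grad=True)
log_var = torch.zeros(4, 2, requires_grad=True)

z = reparameterize(mu, log_var)
z.sum().backward()

# dz/dmu = 1 for every element, so this gradient is all ones
print(mu.grad)
# dz/dlog_var = 0.5 * sigma * eps, so the variance branch receives gradient too
print(log_var.grad)
```

If z were drawn with a non-reparameterized sampler instead, both gradients would be missing, and the encoder could not be trained through the sample.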

Building the Decoder Network

The decoder reconstructs the input from latent samples. It typically mirrors the encoder’s architecture in reverse, although this is a convention rather than a requirement.

class Decoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super(Decoder, self).__init__()
        
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, z):
        h = torch.relu(self.fc1(z))
        x_recon = torch.sigmoid(self.fc2(h))  # Sigmoid for [0,1] pixel values
        return x_recon

For convolutional decoders:

class ConvDecoder(nn.Module):
    def __init__(self, latent_dim=20):
        super(ConvDecoder, self).__init__()
        
        self.fc = nn.Linear(latent_dim, 3136)  # Project to 64 * 7 * 7
        
        self.conv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # 7x7 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),   # 14x14 -> 28x28
            nn.Sigmoid()
        )
        
    def forward(self, z):
        h = self.fc(z)
        h = h.view(-1, 64, 7, 7)  # Reshape to spatial dimensions
        x_recon = self.conv_layers(h)
        return x_recon
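
The spatial sizes noted in the comments of the convolutional encoder and decoder can be checked with a dummy batch. This is a minimal sketch that rebuilds only the convolutional layers with the same settings (activations omitted, since they do not change shapes); the names `down` and `up` are illustrative.

```python
import torch
import torch.nn as nn

# Same layer settings as the convolutional encoder and decoder above
down = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
)
up = nn.Sequential(
    nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
    nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
)

x = torch.zeros(8, 1, 28, 28)  # dummy batch of MNIST-sized images
h = down(x)                    # encoder feature map
x_recon = up(h)                # decoder output

print(h.shape)        # torch.Size([8, 64, 7, 7])
print(x_recon.shape)  # torch.Size([8, 1, 28, 28])
```

The output_padding=1 argument matters here: without it, each transposed convolution would produce 13x13 and then 25x25 instead of doubling the size.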

The output activation depends on your data: sigmoid for normalized images, tanh for [-1, 1] range, or no activation for unbounded data.

Defining the Loss Function

The VAE loss combines two terms: reconstruction loss and KL divergence. The reconstruction loss measures how well the decoder reconstructs inputs, while the KL divergence pulls each encoded distribution N(μ, σ²) toward the standard normal prior N(0, I).

def vae_loss(x_recon, x, mu, log_var):
    """
    VAE loss = Reconstruction Loss + KL Divergence
    
    Args:
        x_recon: Reconstructed input (batch_size, input_dim)
        x: Original input (batch_size, input_dim)
        mu: Mean of latent distribution (batch_size, latent_dim)
        log_var: Log variance of latent distribution (batch_size, latent_dim)
    
    Returns:
        Total loss (scalar)
    """
    # Reconstruction loss (binary cross-entropy for binary/normalized data)
    recon_loss = nn.functional.binary_cross_entropy(
        x_recon, x, reduction='sum'
    )
    
    # KL divergence: -0.5 * sum(1 + log(σ²) - μ² - σ²)
    kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    
    return recon_loss + kl_div

For continuous data, use MSE instead of BCE:

# Alternative reconstruction loss for continuous data
recon_loss = nn.functional.mse_loss(x_recon, x, reduction='sum')

The KL divergence formula assumes a diagonal Gaussian posterior and an N(0, I) prior, which gives it a closed-form solution. This regularization keeps encodings from scattering into isolated, far-apart regions of the latent space and encourages smooth interpolation between them.
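
One way to gain confidence in the closed-form expression is to compare it with the value computed by torch.distributions. The following is a small check on random inputs, not part of the training code; the tensor shapes are arbitrary.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(5, 3)
log_var = torch.randn(5, 3)

# Closed form used in vae_loss
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Reference: KL(N(mu, sigma^2) || N(0, 1)), summed over all elements
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum()

print(torch.allclose(kl_closed, kl_ref, atol=1e-4))
```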

Training the VAE

Here’s a complete training implementation combining all components:

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
        
    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var

# Training script
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters
batch_size = 128
learning_rate = 1e-3
epochs = 10
latent_dim = 20

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(input_dim=784, hidden_dim=400, latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
model.train()
for epoch in range(epochs):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, 784).to(device)  # Flatten images
        
        optimizer.zero_grad()
        x_recon, mu, log_var = model(data)
        loss = vae_loss(x_recon, data, mu, log_var)
        
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
    
    avg_loss = train_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

Generating New Samples

Once trained, VAEs can generate new samples by sampling from the latent space and decoding.

def generate_samples(model, num_samples=16):
    """Generate new samples by sampling from N(0, I)"""
    model.eval()
    with torch.no_grad():
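        # Sample from the standard normal prior, on the same device as the model
        device = next(model.parameters()).device
        latent_dim = model.encoder.fc_mu.out_features
        z = torch.randn(num_samples, latent_dim, device=device)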
        samples = model.decoder(z)
        samples = samples.view(-1, 1, 28, 28)  # Reshape for visualization
    return samples

# Generate samples
new_samples = generate_samples(model, num_samples=16)

For latent space interpolation between two images:

def interpolate_latent(model, x1, x2, steps=10):
    """Interpolate between two images in latent space"""
    model.eval()
    with torch.no_grad():
        mu1, _ = model.encoder(x1)
        mu2, _ = model.encoder(x2)
        
        # Linear interpolation in latent space
        interpolations = []
        for alpha in torch.linspace(0, 1, steps):
            z_interp = (1 - alpha) * mu1 + alpha * mu2
            x_interp = model.decoder(z_interp)
            interpolations.append(x_interp)
        
        return torch.stack(interpolations)

# Interpolate between two MNIST digits
x1 = train_dataset[0][0].view(1, -1).to(device)
x2 = train_dataset[1][0].view(1, -1).to(device)
interpolated = interpolate_latent(model, x1, x2, steps=10)

This interpolation illustrates the smooth, continuous latent space a VAE tends to learn. Because the KL term keeps encodings close to the prior, points between two encodings usually decode to plausible outputs, which a standard autoencoder does not encourage.

The key to successful VAE implementation is balancing reconstruction quality with latent space structure. If reconstruction is poor, increase the hidden dimensions or scale the KL term by a weight below 1. If the latent space isn’t being used (posterior collapse), try warming up the KL weight from zero or using a cyclical annealing schedule.
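
The loss defined earlier has no KL weight, so adjusting it requires a small change. The sketch below shows one possible way to add a weight and a simple cyclical schedule. Both `weighted_vae_loss` and `cyclical_beta`, along with their parameters and default values, are hypothetical names introduced here for illustration, not tuned recommendations.

```python
import torch
import torch.nn.functional as F

def weighted_vae_loss(x_recon, x, mu, log_var, beta=1.0):
    """Hypothetical variant of vae_loss with a weight beta on the KL term."""
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + beta * kl_div

def cyclical_beta(step, cycle_len=1000, ramp_fraction=0.5, beta_max=1.0):
    """Hypothetical schedule: within each cycle, ramp beta linearly from 0
    to beta_max over the first ramp_fraction of the cycle, then hold it."""
    position = (step % cycle_len) / cycle_len  # position in cycle, in [0, 1)
    return beta_max * min(1.0, position / ramp_fraction)

# Inside the training loop, with a running step counter (assumed):
#   beta = cyclical_beta(global_step)
#   loss = weighted_vae_loss(x_recon, data, mu, log_var, beta=beta)
```

With beta=1.0 this reduces to the original loss, so the schedule can be introduced without changing anything else in the training script.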
