How to Implement a VAE in TensorFlow

Key Insights

  • VAEs learn a probabilistic latent representation by encoding inputs as distributions rather than fixed points, enabling smooth interpolation and controlled generation of new samples
  • The reparameterization trick is essential for training VAEs—it moves randomness outside the gradient path by sampling from N(μ, σ²) as μ + σ * ε where ε ~ N(0,1)
  • The VAE loss balances reconstruction accuracy with latent space structure through a weighted combination of reconstruction loss and KL divergence regularization

Introduction to Variational Autoencoders

Variational Autoencoders represent a powerful class of generative models that learn compressed representations of data while maintaining the ability to generate new, realistic samples. Unlike standard autoencoders that map inputs to fixed points in latent space, VAEs encode inputs as probability distributions—specifically, as means and variances of Gaussian distributions.

This probabilistic approach provides several advantages. The latent space becomes continuous and well-structured, allowing smooth interpolation between points. You can sample random points from the prior N(0, I) and decode them to generate entirely new data. VAEs are also well suited to dimensionality reduction, anomaly detection (inputs with unusually high reconstruction error can be flagged as anomalies), and conditional generation tasks.

The architecture consists of three components: an encoder network that maps inputs to latent distribution parameters, a sampling operation that draws from these distributions, and a decoder network that reconstructs the original input from sampled latent vectors. The key innovation is the reparameterization trick, which enables gradient-based optimization despite the stochastic sampling operation.

Understanding the VAE Loss Function

The VAE loss function combines two competing objectives. The reconstruction loss measures how accurately the decoder recreates the original input from the latent representation. For image data, this is typically binary cross-entropy (for pixel values scaled to [0, 1]) or mean squared error. Optimized on its own, this term would simply train a standard autoencoder.

The second component, KL divergence, regularizes the latent space by encouraging the learned distributions to approximate a standard normal distribution N(0, I). This regularization prevents the encoder from “cheating” by encoding each input to a completely separate region of latent space with zero variance. By constraining distributions toward N(0, I), we ensure the latent space remains continuous and structured.
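To see why the KL term discourages this "cheating", the short sketch below evaluates the closed-form KL penalty for a single latent dimension with μ = 0 while its log-variance shrinks toward minus infinity (that is, σ² toward zero). It assumes TensorFlow 2.x; the helper name `kl_per_dim` is hypothetical and only for illustration.

```python
import tensorflow as tf

def kl_per_dim(mu, log_var):
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) for each latent dimension separately
    return -0.5 * (1.0 + log_var - tf.square(mu) - tf.exp(log_var))

# Keep the mean at 0 and make the variance smaller and smaller
log_vars = tf.constant([0.0, -2.0, -5.0, -10.0])
penalties = kl_per_dim(tf.zeros_like(log_vars), log_vars)

for lv, kl in zip(log_vars.numpy(), penalties.numpy()):
    print(f"log_var={lv:6.1f}  KL={kl:.4f}")
```

The penalty is zero when the code matches N(0, 1) exactly and grows roughly linearly as log σ² becomes more negative, so collapsing an encoding to a near-deterministic point is never free.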

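The mathematical formulation is: Loss = Reconstruction_Loss + β * KL_Divergence, where β controls the trade-off. With β = 1 this is the standard VAE objective (when the reconstruction loss is a negative log-likelihood such as binary cross-entropy, it is the negative evidence lower bound, or ELBO); the β-VAE variant uses β > 1. Higher β values produce a more structured latent space but usually blurrier, less accurate reconstructions.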
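As a sketch of how β enters the objective, the hypothetical helper `vae_loss` below combines a per-image summed binary cross-entropy with the Gaussian KL term and weights the latter by `beta`. The random tensors are stand-ins for real inputs and network outputs, chosen only to exercise the shapes.

```python
import tensorflow as tf
from tensorflow import keras

def vae_loss(x, reconstruction, mu, log_var, beta=1.0):
    # Reconstruction: per-pixel BCE, summed over each image, averaged over the batch
    bce = keras.losses.binary_crossentropy(x, reconstruction)  # shape (batch, H, W)
    recon = tf.reduce_mean(tf.reduce_sum(bce, axis=(1, 2)))
    # KL: closed form for a diagonal Gaussian vs N(0, I), summed over latent dims
    kl = tf.reduce_mean(
        -0.5 * tf.reduce_sum(1.0 + log_var - tf.square(mu) - tf.exp(log_var), axis=1)
    )
    # beta = 1 is the standard VAE; beta > 1 puts more weight on the regularizer
    return recon + beta * kl, recon, kl

tf.random.set_seed(0)
x = tf.random.uniform((4, 28, 28, 1))                         # stand-in images in [0, 1]
reconstruction = tf.random.uniform((4, 28, 28, 1), 0.01, 0.99)  # stand-in decoder output
mu = tf.random.normal((4, 2))
log_var = tf.random.normal((4, 2))

for beta in (0.0, 1.0, 4.0):
    total, recon, kl = vae_loss(x, reconstruction, mu, log_var, beta)
    print(f"beta={beta}: total={total.numpy():.2f} recon={recon.numpy():.2f} kl={kl.numpy():.2f}")
```

Only the total changes with β; the two component values stay the same, which is why logging them separately (as the training code later does) makes the trade-off easy to inspect.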

Here’s how to calculate KL divergence for Gaussian distributions in TensorFlow:

import tensorflow as tf

def kl_divergence_loss(mu, log_var):
    """
    Calculate KL divergence between learned distribution and N(0,1).
    For multivariate Gaussian: KL = -0.5 * sum(1 + log(σ²) - μ² - σ²)
    """
    kl_loss = -0.5 * tf.reduce_sum(
        1 + log_var - tf.square(mu) - tf.exp(log_var),
        axis=1
    )
    return tf.reduce_mean(kl_loss)

# Example calculation
mu = tf.constant([[0.5, -0.3, 0.8]])
log_var = tf.constant([[-0.2, -0.5, 0.1]])
kl_loss = kl_divergence_loss(mu, log_var)
print(f"KL Divergence: {kl_loss.numpy():.4f}")
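One way to sanity-check the closed-form expression is to compare it with a Monte Carlo estimate of E_q[log q(z) - log p(z)], drawing z from the learned Gaussian q = N(μ, σ²) and using the prior p = N(0, I). The sketch below reuses the example μ and log σ² from above; the function names `kl_closed_form` and `kl_monte_carlo` are hypothetical, and this check is not part of the training code.

```python
import math
import tensorflow as tf

def kl_closed_form(mu, log_var):
    # Same formula as kl_divergence_loss, but one value per example: shape (batch,)
    return -0.5 * tf.reduce_sum(1.0 + log_var - tf.square(mu) - tf.exp(log_var), axis=1)

def kl_monte_carlo(mu, log_var, n_samples=200_000):
    # Sample z ~ N(mu, sigma^2) and average log q(z) - log p(z)
    sigma = tf.exp(0.5 * log_var)
    eps = tf.random.normal((n_samples,) + tuple(mu.shape))
    z = mu + sigma * eps                      # shape (n_samples, batch, latent_dim)
    log2pi = math.log(2.0 * math.pi)
    # (z - mu)^2 / sigma^2 equals eps^2, so log q can be written in terms of eps
    log_q = -0.5 * tf.reduce_sum(log2pi + log_var + tf.square(eps), axis=-1)
    log_p = -0.5 * tf.reduce_sum(log2pi + tf.square(z), axis=-1)
    return tf.reduce_mean(log_q - log_p, axis=0)

tf.random.set_seed(42)
mu = tf.constant([[0.5, -0.3, 0.8]])
log_var = tf.constant([[-0.2, -0.5, 0.1]])
print("closed form:", kl_closed_form(mu, log_var).numpy())
print("Monte Carlo:", kl_monte_carlo(mu, log_var).numpy())
```

The two numbers should agree to within sampling noise; a large gap would point to a sign or exponent error in the closed-form code.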

Building the Encoder Network

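The encoder transforms input data into the parameters of a latent distribution. For a Gaussian distribution, we need two outputs per latent dimension: the mean (μ) and the log-variance (log σ²). A variance must be positive, but a Dense layer can output any real number. Predicting log σ² lets the layer stay unconstrained, while exp(log σ²) always yields a positive variance; it is also numerically better behaved than predicting very small variances directly.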
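The small sketch below illustrates the point with random stand-in values in place of real Dense-layer outputs (an assumption for illustration): whatever sign the raw output has, exponentiating it gives a valid variance, and exp(0.5 * log σ²) gives the matching standard deviation.

```python
import tensorflow as tf

tf.random.set_seed(1)
# Stand-in for raw outputs of the log_var Dense layer: unconstrained real numbers
raw_log_var = tf.random.normal((4, 2), stddev=3.0)

variance = tf.exp(raw_log_var)      # exp maps every real number to a positive value
sigma = tf.exp(0.5 * raw_log_var)   # standard deviation, i.e. sqrt(variance)

# Read directly as variances, the same raw outputs can include invalid negatives
n_negative = int(tf.reduce_sum(tf.cast(raw_log_var < 0, tf.int32)))
print("raw outputs below zero:", n_negative)
print("smallest variance after exp:", float(tf.reduce_min(variance)))
```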

For MNIST images (28x28 grayscale), a typical encoder uses convolutional layers to extract features, followed by dense layers that output the distribution parameters. The architecture should progressively reduce spatial dimensions while increasing channel depth.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class Encoder(keras.Model):
    def __init__(self, latent_dim=2, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        
        # Convolutional feature extraction
        self.conv1 = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')
        self.conv2 = layers.Conv2D(64, 3, strides=2, padding='same', activation='relu')
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(256, activation='relu')
        
        # Distribution parameters
        self.mu_layer = layers.Dense(latent_dim, name='mu')
        self.log_var_layer = layers.Dense(latent_dim, name='log_var')
    
    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.flatten(x)
        x = self.dense(x)
        
        mu = self.mu_layer(x)
        log_var = self.log_var_layer(x)
        
        return mu, log_var

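This encoder's two stride-2 convolutions shrink each 28x28 image to a 7x7x64 feature map (3,136 values), which the dense layers compress into latent_dim means and latent_dim log-variances. A latent_dim of 2 is convenient for plotting; larger values work the same way. The separate output layers for μ and log σ² matter: they let the network learn both the location and the uncertainty of each encoding.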
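To check the shape arithmetic, the sketch below applies layers configured like those in the Encoder class, one at a time, to a dummy batch of zeros. It assumes TensorFlow 2.x with the same `tensorflow.keras.layers` import used above; fresh layers are created here, so no weights are shared with a real model.

```python
import tensorflow as tf
from tensorflow.keras import layers

latent_dim = 2
x = tf.zeros((8, 28, 28, 1))  # dummy batch of MNIST-sized images

# Same layer configuration as the Encoder class, applied step by step
h1 = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(x)   # 28 -> 14
h2 = layers.Conv2D(64, 3, strides=2, padding='same', activation='relu')(h1)  # 14 -> 7
flat = layers.Flatten()(h2)                                                   # 7 * 7 * 64
hidden = layers.Dense(256, activation='relu')(flat)
mu = layers.Dense(latent_dim)(hidden)
log_var = layers.Dense(latent_dim)(hidden)

for name, t in [("input", x), ("conv1", h1), ("conv2", h2), ("flatten", flat),
                ("dense", hidden), ("mu", mu), ("log_var", log_var)]:
    print(f"{name:8s} {tuple(t.shape)}")
```

Each stride-2 convolution with `padding='same'` halves the spatial size while the channel count grows, which is the "reduce spatial dimensions, increase depth" pattern described earlier.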

Implementing the Reparameterization Trick

Direct sampling from N(μ, σ²) is not differentiable, which breaks backpropagation. The reparameterization trick solves this by expressing the random variable as a deterministic function of the parameters plus independent noise: z = μ + σ * ε, where ε ~ N(0, 1).

This reformulation moves the randomness outside the gradient path. Gradients can flow through μ and σ, while ε remains a constant during backpropagation. This simple trick is what makes VAEs trainable with standard gradient descent.
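A quick way to confirm that gradients reach μ and log σ² is to differentiate a reparameterized sample with `tf.GradientTape`. In this sketch the "loss" is simply the sum of z, an arbitrary stand-in for the real VAE loss. The expected gradients are dz/dμ = 1 and dz/d(log σ²) = 0.5 · σ · ε; ε gets no gradient because it is a plain tensor rather than a variable.

```python
import tensorflow as tf

tf.random.set_seed(0)
mu = tf.Variable([[0.5, -1.0]])
log_var = tf.Variable([[0.0, -2.0]])
epsilon = tf.random.normal(mu.shape)  # noise: a constant input, not a trainable variable

with tf.GradientTape() as tape:
    z = mu + tf.exp(0.5 * log_var) * epsilon  # reparameterized sample
    loss = tf.reduce_sum(z)                    # any differentiable function of z works

grad_mu, grad_log_var = tape.gradient(loss, [mu, log_var])
print("dL/dmu:     ", grad_mu.numpy())
print("dL/dlog_var:", grad_log_var.numpy())
```

Had z been drawn with `tf.random.normal(mean=mu, ...)` directly, there would be no differentiable path from z back to the variance parameter, which is exactly the problem the trick removes.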

class Sampling(layers.Layer):
    """Sampling layer using the reparameterization trick."""
    
    def call(self, inputs):
        mu, log_var = inputs
        batch_size = tf.shape(mu)[0]
        latent_dim = tf.shape(mu)[1]
        
        # Sample epsilon from standard normal
        epsilon = tf.random.normal(shape=(batch_size, latent_dim))
        
        # Reparameterization: z = mu + sigma * epsilon
        # Use exp(0.5 * log_var) to get sigma from log_var
        return mu + tf.exp(0.5 * log_var) * epsilon

The use of exp(0.5 * log_var) converts log-variance to standard deviation (σ = √(σ²) = √(exp(log σ²)) = exp(0.5 * log σ²)). This layer can be used like any other Keras layer in your model.
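To check that the reparameterized samples really follow N(μ, σ²), the sketch below repeats the Sampling layer's computation in a standalone function (the name `reparameterize` is hypothetical) and compares sample statistics against the chosen μ = [1.5, -2.0] and σ = [1.0, 0.5].

```python
import math
import tensorflow as tf

def reparameterize(mu, log_var):
    # Same computation as the Sampling layer's call()
    epsilon = tf.random.normal(tf.shape(mu))
    return mu + tf.exp(0.5 * log_var) * epsilon

tf.random.set_seed(7)
n = 100_000
mu = tf.broadcast_to(tf.constant([1.5, -2.0]), (n, 2))
log_var = tf.broadcast_to(tf.constant([0.0, math.log(0.25)]), (n, 2))  # sigma = 1.0, 0.5

z = reparameterize(mu, log_var)
print("sample mean:", tf.reduce_mean(z, axis=0).numpy())      # close to [1.5, -2.0]
print("sample std: ", tf.math.reduce_std(z, axis=0).numpy())  # close to [1.0, 0.5]
```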

Building the Decoder Network

The decoder mirrors the encoder architecture in reverse, transforming latent vectors back to the original data space. It should progressively upsample from the latent dimension to the full input size.

class Decoder(keras.Model):
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        
        self.dense1 = layers.Dense(7 * 7 * 64, activation='relu')
        self.reshape = layers.Reshape((7, 7, 64))
        
        # Transposed convolutions for upsampling
        self.conv_t1 = layers.Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')
        self.conv_t2 = layers.Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu')
        
        # Output layer - sigmoid for normalized pixel values [0, 1]
        self.output_layer = layers.Conv2D(1, 3, padding='same', activation='sigmoid')
    
    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.reshape(x)
        x = self.conv_t1(x)
        x = self.conv_t2(x)
        return self.output_layer(x)

The sigmoid activation in the output layer is essential when using binary cross-entropy loss: it keeps outputs in the [0, 1] range, matching the normalized pixel values used as targets. For other data types, adjust the activation accordingly (tanh for data scaled to [-1, 1], linear for continuous unbounded data, typically paired with a mean squared error loss).
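The sketch below traces a batch of random 2D latent vectors through layers configured like those in the Decoder class, checking both the upsampling path (7 to 14 to 28) and that the sigmoid output stays within [0, 1]. As with the encoder check, these are freshly created, untrained layers used only to confirm shapes and ranges.

```python
import tensorflow as tf
from tensorflow.keras import layers

z = tf.random.normal((4, 2))  # batch of 2D latent vectors

# Same layer configuration as the Decoder class, applied step by step
h = layers.Dense(7 * 7 * 64, activation='relu')(z)
h = layers.Reshape((7, 7, 64))(h)
up1 = layers.Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')(h)    # 7 -> 14
up2 = layers.Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu')(up1)  # 14 -> 28
out = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(up2)                    # 1 channel

print("shapes:", tuple(up1.shape), tuple(up2.shape), tuple(out.shape))
print("value range:", float(tf.reduce_min(out)), float(tf.reduce_max(out)))
```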

Training the Complete VAE Model

Combining all components requires a custom training step that computes both loss components. We’ll create a VAE class that inherits from keras.Model and overrides the train_step method.

class VAE(keras.Model):
    def __init__(self, encoder, decoder, beta=1.0, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.sampling = Sampling()
        # Weight on the KL term; beta = 1 is the standard VAE objective
        self.beta = beta
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
    
    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]
    
    def train_step(self, data):
        # fit() may wrap inputs in a tuple such as (x,) or (x, y); the VAE only needs x
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            # Encode
            mu, log_var = self.encoder(data)
            # Sample
            z = self.sampling([mu, log_var])
            # Decode
            reconstruction = self.decoder(z)
            
            # Reconstruction loss: per-pixel BCE summed over each image, averaged over the batch
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2)
                )
            )
            
            # KL divergence loss
            kl_loss = -0.5 * tf.reduce_mean(
                tf.reduce_sum(1 + log_var - tf.square(mu) - tf.exp(log_var), axis=1)
            )
            
            # Total loss, with beta weighting the KL term
            total_loss = reconstruction_loss + self.beta * kl_loss
        
        # Compute gradients and update weights
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        # Update metrics
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

# Training setup
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)

encoder = Encoder(latent_dim=2)
decoder = Decoder()
vae = VAE(encoder, decoder)

vae.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3))
vae.fit(x_train, epochs=30, batch_size=128)

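Monitor all three metrics during training. Do not expect the two terms to be equal in size: the reconstruction loss is summed over all 784 pixels, so it is normally much larger than the KL term. Watch how they behave instead. If the KL loss falls to near zero while reconstructions stay blurry, the regularizer is typically overpowering reconstruction and the decoder is ignoring the latent code (posterior collapse); lower β. If reconstructions are sharp but samples decoded from N(0, I) look poor, the latent space is not regular enough; raise β.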
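For a sense of scale, the sketch below computes the reconstruction loss of an uninformative decoder that predicts 0.5 for every pixel. Per-pixel binary cross-entropy is then ln 2 regardless of the target, so the summed loss is 784 · ln 2 ≈ 543 nats per image. The random images are stand-ins; any targets in [0, 1] give the same result.

```python
import math
import tensorflow as tf
from tensorflow import keras

x = tf.random.uniform((16, 28, 28, 1))          # stand-in targets in [0, 1]
prediction = tf.fill((16, 28, 28, 1), 0.5)      # "know nothing" decoder output

per_pixel = keras.losses.binary_crossentropy(x, prediction)  # shape (16, 28, 28)
per_image = tf.reduce_sum(per_pixel, axis=(1, 2))             # summed as in train_step
print("baseline reconstruction loss:", float(tf.reduce_mean(per_image)))
print("784 * ln 2 =", 784 * math.log(2.0))
```

A trained model's reconstruction loss should sit well below this baseline, while the KL term for a 2D latent space stays far smaller in absolute terms.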

Generating New Samples and Visualizing Results

Once trained, you can generate new samples by sampling random points from the prior distribution N(0, I) and passing them through the decoder. You can also interpolate between two points in latent space to create smooth transitions.

import matplotlib.pyplot as plt
import numpy as np

# Generate new samples
def generate_samples(decoder, n_samples=10, latent_dim=2):
    random_latent_vectors = tf.random.normal(shape=(n_samples, latent_dim))
    generated_images = decoder(random_latent_vectors)
    return generated_images

# Visualize generated samples
generated = generate_samples(decoder, n_samples=10)
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(generated[i].numpy().squeeze(), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()

# Latent space interpolation
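def interpolate_latent(decoder, start_point, end_point, steps=10):
    # Blend weights from 0 to 1, shaped (steps, 1) so they broadcast over latent dims
    alphas = tf.linspace(0.0, 1.0, steps)[:, tf.newaxis]
    # Linear interpolation between the two latent vectors: shape (steps, latent_dim)
    points = start_point[tf.newaxis, :] * (1.0 - alphas) + end_point[tf.newaxis, :] * alphas
    # Decode all points as one batch; returns (steps, 28, 28, 1)
    return decoder(points)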

# Interpolate between two random points
start = tf.random.normal(shape=(2,))
end = tf.random.normal(shape=(2,))
interpolated_images = interpolate_latent(decoder, start, end, steps=10)

fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for i, ax in enumerate(axes):
    ax.imshow(interpolated_images[i].numpy().squeeze(), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()

# Visualize reconstruction quality
def plot_reconstructions(vae, test_data, n_samples=10):
    # Encode, sample with the reparameterization trick, then decode
    mu, log_var = vae.encoder(test_data[:n_samples])
    z = vae.sampling([mu, log_var])
    reconstructions = vae.decoder(z)
    
    fig, axes = plt.subplots(2, n_samples, figsize=(15, 3))
    for i in range(n_samples):
        # test_data is a NumPy array, so index it directly (no .numpy() call)
        axes[0, i].imshow(test_data[i].squeeze(), cmap='gray')
        axes[1, i].imshow(reconstructions[i].numpy().squeeze(), cmap='gray')
        # Hide ticks but keep the axes, so the row labels set below stay visible
        for ax in (axes[0, i], axes[1, i]):
            ax.set_xticks([])
            ax.set_yticks([])
    axes[0, 0].set_ylabel('Original', size=12)
    axes[1, 0].set_ylabel('Reconstructed', size=12)
    plt.tight_layout()
    plt.show()

x_test = x_test.astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1)
plot_reconstructions(vae, x_test)

The quality of generated samples depends heavily on the latent dimension, network capacity, and training duration. For MNIST, a 2D latent space can be plotted directly but limits how much detail the model can capture. For more complex datasets, a larger latent space (for example 10 to 20 dimensions) is a reasonable starting point. Experiment with β-VAE (varying β in the loss function) to control the trade-off between reconstruction quality and latent space disentanglement.
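With a 2D latent space, a common way to visualize what the decoder has learned is to decode a regular grid of latent points and tile the results into one image. The sketch below assumes a decoder mapping (batch, 2) latent vectors to (batch, 28, 28, 1) images; the function `decode_latent_grid`, its `n` and `scale` defaults, and `toy_decoder` (a stand-in used only to exercise the code) are all hypothetical.

```python
import numpy as np
import tensorflow as tf

def decode_latent_grid(decoder, n=15, scale=3.0, image_size=28):
    # Evenly spaced coordinates in [-scale, scale] along each latent axis
    grid = np.linspace(-scale, scale, n, dtype=np.float32)
    # Rows run from high to low z[1], so the second latent axis points upward in the plot
    points = np.array([[x, y] for y in grid[::-1] for x in grid], dtype=np.float32)
    images = np.asarray(decoder(tf.constant(points))).reshape(n, n, image_size, image_size)
    # Tile (row, col, h, w) into one (row * h, col * w) canvas
    return images.transpose(0, 2, 1, 3).reshape(n * image_size, n * image_size)

# Hypothetical stand-in decoder: image brightness follows the first latent coordinate
def toy_decoder(z):
    return tf.sigmoid(z[:, 0])[:, None, None, None] * tf.ones((1, 28, 28, 1))

canvas = decode_latent_grid(toy_decoder, n=5)
print(canvas.shape)  # (140, 140)
```

With the trained model from above, `plt.imshow(decode_latent_grid(decoder), cmap='gray')` shows how digit identity and style change smoothly across the latent plane.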

VAEs provide a principled framework for learning generative models with interpretable latent representations. The implementation in TensorFlow is straightforward once you understand the core components: probabilistic encoding, reparameterization for gradient flow, and the dual-objective loss function. Start with this foundation and experiment with architectural variations, different β values, and conditional VAEs for controlled generation.
