How to Implement a GAN in TensorFlow

Key Insights

  • GANs consist of two competing networks—a generator that creates fake data and a discriminator that learns to distinguish real from fake—trained simultaneously through adversarial loss functions
  • Proper training requires careful balance: train the discriminator on both real and generated samples, then update the generator to fool the discriminator, alternating these steps throughout training
  • Use separate Adam optimizers with different learning rates for each network, apply batch normalization in the generator, and add dropout in the discriminator to stabilize training and prevent mode collapse

Introduction to GANs

Generative Adversarial Networks (GANs) represent one of the most exciting developments in deep learning. Introduced by Ian Goodfellow in 2014, GANs learn to generate new data that resembles a training dataset through an adversarial process between two neural networks.

The architecture consists of two components: a generator network that creates synthetic data from random noise, and a discriminator network that attempts to classify whether samples are real or generated. During training, the generator learns to produce increasingly realistic samples to fool the discriminator, while the discriminator becomes better at detecting fakes. This adversarial dynamic drives both networks to improve, eventually producing a generator capable of creating highly realistic synthetic data.

GANs have numerous practical applications: generating photorealistic images, creating training data for scenarios where real data is scarce, image-to-image translation, super-resolution enhancement, and even generating synthetic medical images for research. In this article, we’ll implement a GAN in TensorFlow to generate handwritten digits using the MNIST dataset.

Setting Up the Environment

Before building our GAN, ensure you have TensorFlow 2.x installed along with NumPy and Matplotlib. Install them using pip:

pip install tensorflow numpy matplotlib

We’ll use the MNIST dataset, which contains 60,000 training images of handwritten digits. TensorFlow provides convenient access to this dataset through tf.keras.datasets.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers

# Load and preprocess MNIST dataset
(x_train, _), (_, _) = keras.datasets.mnist.load_data()

# Normalize images to [-1, 1] range to match tanh activation
x_train = x_train.astype('float32')
x_train = (x_train - 127.5) / 127.5

# Flatten each 28x28 image into a 784-dimensional vector for the dense networks
x_train = x_train.reshape(x_train.shape[0], 784)

# Configuration
BUFFER_SIZE = 60000
BATCH_SIZE = 256
NOISE_DIM = 100
EPOCHS = 50

# Create batched dataset
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Note that we normalize pixel values to the range [-1, 1] rather than [0, 1]. This matches the output range of the tanh activation function we’ll use in the generator’s final layer, which helps stabilize training.
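As a quick standalone sanity check (not needed for the tutorial itself), you can verify that this rescaling maps the 0–255 pixel range exactly onto [-1, 1]:

```python
import numpy as np

# The same rescaling used above: 0 -> -1, 127.5 -> 0, 255 -> 1
pixels = np.array([0, 127.5, 255], dtype='float32')
scaled = (pixels - 127.5) / 127.5
print(scaled)  # [-1.  0.  1.]
```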

Building the Generator Network

The generator transforms random noise vectors into synthetic images. It takes a noise vector (typically sampled from a normal distribution) and progressively upsamples it through dense layers until it produces an output matching the dimensions of real images.

def build_generator():
    model = keras.Sequential([
        # Input layer takes noise vector
        layers.Dense(256, input_shape=(NOISE_DIM,)),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        # Output layer produces flattened 28x28 image
        layers.Dense(784, activation='tanh')
    ], name='generator')
    
    return model

generator = build_generator()
generator.summary()

Key architectural choices:

  • LeakyReLU activation: Allows small gradients for negative values, preventing “dying ReLU” problems common in GANs
  • Batch normalization: Stabilizes training by normalizing layer inputs, crucial for GAN convergence
  • Tanh output activation: Produces values in [-1, 1], matching our normalized input data range
  • Progressive expansion: Each layer increases capacity, allowing the network to learn hierarchical features
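To get a feel for the "progressive expansion," here is a rough standalone count of the weights and biases in the generator's dense stack (mirroring the layer sizes in `build_generator()` above; batch-normalization parameters are excluded, so `model.summary()` will report a slightly larger total):

```python
# Layer widths: NOISE_DIM -> 256 -> 512 -> 1024 -> 784 (flattened 28x28 image)
sizes = [100, 256, 512, 1024, 784]

# Each Dense layer has n_in * n_out weights plus n_out biases
params = sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))
print(params)  # 1486352
```

Most of the capacity sits in the later, wider layers, which is where the network shapes the noise into image-like structure.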

Building the Discriminator Network

The discriminator is a binary classifier that determines whether an input image is real or generated. It should be powerful enough to provide useful gradients to the generator but not so powerful that it perfectly classifies everything, which would stop the generator from learning.

def build_discriminator():
    model = keras.Sequential([
        layers.Dense(512, input_shape=(784,)),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        # Output layer: probability that input is real
        layers.Dense(1, activation='sigmoid')
    ], name='discriminator')
    
    return model

discriminator = build_discriminator()
discriminator.summary()

Important design considerations:

  • Dropout layers: Prevent the discriminator from overfitting and memorizing training samples, which would cause mode collapse
  • Sigmoid output: Produces a probability score between 0 (fake) and 1 (real)
  • Simpler architecture than generator: The discriminator’s task is easier, so we use fewer parameters to maintain balance
  • No batch normalization: In the discriminator, batch normalization can sometimes cause training instability
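The LeakyReLU choice in both networks is easy to see in isolation. A plain ReLU outputs zero for all negative inputs, so those units contribute no gradient; LeakyReLU keeps a small slope instead. A minimal NumPy sketch of the activation:

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # Pass positive values through; scale negative values by alpha
    return np.where(x > 0, x, alpha * x)

x = np.array([-2.0, -0.5, 0.0, 1.0])
print(leaky_relu(x))  # [-0.4 -0.1  0.   1. ]
```

Negative inputs are attenuated rather than zeroed, so gradients keep flowing through those units during training.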

Defining Loss Functions and Optimizers

GANs use binary cross-entropy loss for both networks, but with different targets. The discriminator maximizes its ability to classify real vs. fake, while the generator minimizes the discriminator’s ability to detect fakes.

# Binary cross-entropy loss
cross_entropy = keras.losses.BinaryCrossentropy()

def discriminator_loss(real_output, fake_output):
    # Discriminator should output 1 for real images
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    # Discriminator should output 0 for fake images
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    # Generator wants discriminator to output 1 for fake images
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# Separate optimizers for each network
generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

We use Adam optimizers with a lower learning rate (0.0002) and modified beta_1 (0.5) compared to default settings. These hyperparameters have been empirically shown to work well for GAN training, providing more stable convergence.
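These losses also give you a useful baseline for monitoring training. An undecided discriminator that outputs 0.5 for everything incurs a binary cross-entropy of ln 2 ≈ 0.693 per term, so its combined loss sits near 2·ln 2 ≈ 1.386. A standalone check with plain Python (no TensorFlow required):

```python
import math

p = 0.5  # discriminator output for both real and fake samples
real_loss = -math.log(p)        # BCE against target 1
fake_loss = -math.log(1 - p)    # BCE against target 0
print(round(real_loss + fake_loss, 4))  # 1.3863
```

If the discriminator loss stays near this value while the generator loss stays near 0.693, the two networks are roughly in balance; a discriminator loss collapsing toward zero suggests it is overpowering the generator.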

Training Loop Implementation

The training loop alternates between updating the discriminator and generator. Each iteration involves: (1) training the discriminator on real and fake samples, (2) training the generator to fool the discriminator.

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate fake images
        generated_images = generator(noise, training=True)
        
        # Get discriminator outputs
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        
        # Calculate losses
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    
    # Calculate gradients
    gradients_of_generator = gen_tape.gradient(
        gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, discriminator.trainable_variables)
    
    # Apply gradients
    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    return gen_loss, disc_loss

def train(dataset, epochs):
    for epoch in range(epochs):
        gen_loss_avg = keras.metrics.Mean()
        disc_loss_avg = keras.metrics.Mean()
        
        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)
            gen_loss_avg.update_state(gen_loss)
            disc_loss_avg.update_state(disc_loss)
        
        # Print progress
        if (epoch + 1) % 5 == 0:
            print(f'Epoch {epoch + 1}, Gen Loss: {gen_loss_avg.result():.4f}, '
                  f'Disc Loss: {disc_loss_avg.result():.4f}')
            generate_and_save_images(generator, epoch + 1)

# Start training (define generate_and_save_images from the next section
# first, since the training loop calls it to save sample images)
train(train_dataset, EPOCHS)

The @tf.function decorator compiles the training step into a TensorFlow graph, significantly speeding up execution. We use tf.GradientTape to record operations for automatic differentiation, computing gradients for both networks simultaneously.

Visualizing Results and Evaluation

Monitoring generated samples throughout training helps you assess GAN performance and detect issues like mode collapse (when the generator produces limited variety) or training divergence.

def generate_and_save_images(model, epoch, test_input=None):
    if test_input is None:
        test_input = tf.random.normal([16, NOISE_DIM])
    
    predictions = model(test_input, training=False)
    predictions = (predictions + 1) / 2.0  # Rescale from [-1,1] to [0,1]
    
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i].numpy().reshape(28, 28), cmap='gray')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(f'generated_epoch_{epoch}.png')
    plt.close()

# Generate final samples
noise = tf.random.normal([16, NOISE_DIM])
generated_images = generator(noise, training=False)
generate_and_save_images(generator, 'final', noise)

For quantitative evaluation, consider implementing metrics like Inception Score or Fréchet Inception Distance (FID), though visual inspection often suffices for simple datasets like MNIST.

GANs are notoriously difficult to train, but following these practices—balanced architectures, proper normalization, separate optimizers, and careful monitoring—will help you achieve stable training and generate high-quality synthetic data. Experiment with different architectures, hyperparameters, and training techniques to improve your results.
