How to Implement a GAN in TensorFlow
Key Insights
- GANs consist of two competing networks—a generator that creates fake data and a discriminator that learns to distinguish real from fake—trained simultaneously through adversarial loss functions
- Proper training requires careful balance: train the discriminator on both real and generated samples, then update the generator to fool the discriminator, alternating these steps throughout training
- Use separate Adam optimizers with different learning rates for each network, apply batch normalization in the generator, and add dropout in the discriminator to stabilize training and prevent mode collapse
Introduction to GANs
Generative Adversarial Networks (GANs) represent one of the most exciting developments in deep learning. Introduced by Ian Goodfellow in 2014, GANs learn to generate new data that resembles a training dataset through an adversarial process between two neural networks.
The architecture consists of two components: a generator network that creates synthetic data from random noise, and a discriminator network that attempts to classify whether samples are real or generated. During training, the generator learns to produce increasingly realistic samples to fool the discriminator, while the discriminator becomes better at detecting fakes. This adversarial dynamic drives both networks to improve, eventually producing a generator capable of creating highly realistic synthetic data.
GANs have numerous practical applications: generating photorealistic images, creating training data for scenarios where real data is scarce, image-to-image translation, super-resolution enhancement, and even generating synthetic medical images for research. In this article, we’ll implement a GAN in TensorFlow to generate handwritten digits using the MNIST dataset.
Setting Up the Environment
Before building our GAN, ensure you have TensorFlow 2.x installed along with necessary dependencies. Install them using pip:
pip install tensorflow numpy matplotlib
We’ll use the MNIST dataset, which contains 60,000 training images of handwritten digits. TensorFlow provides convenient access to this dataset through tf.keras.datasets.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
# Load and preprocess MNIST dataset
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
# Normalize images to [-1, 1] range to match tanh activation
x_train = x_train.astype('float32')
x_train = (x_train - 127.5) / 127.5
# Flatten each 28x28 image into a 784-dimensional vector
x_train = x_train.reshape(x_train.shape[0], 784)
# Configuration
BUFFER_SIZE = 60000
BATCH_SIZE = 256
NOISE_DIM = 100
EPOCHS = 50
# Create batched dataset
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
Note that we normalize pixel values to the range [-1, 1] rather than [0, 1]. This matches the output range of the tanh activation function we’ll use in the generator’s final layer, which helps stabilize training.
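As a quick sanity check, the [-1, 1] mapping and its inverse (the same rescaling we apply later when visualizing samples) can be verified with plain NumPy, independent of the training code:

```python
import numpy as np

# Map raw pixel values [0, 255] into [-1, 1], as done for x_train above
pixels = np.array([0.0, 127.5, 255.0], dtype='float32')
normalized = (pixels - 127.5) / 127.5
print(normalized)  # -1.0, 0.0, 1.0

# Inverse mapping back to [0, 1] for display, used when plotting samples
restored = (normalized + 1) / 2.0
print(restored)  # 0.0, 0.5, 1.0
```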
Building the Generator Network
The generator transforms random noise vectors into synthetic images. It takes a noise vector (typically sampled from a standard normal distribution) and progressively expands it through successively larger dense layers until the output matches the dimensions of real images.
def build_generator():
    model = keras.Sequential([
        # Input layer takes noise vector
        layers.Dense(256, input_shape=(NOISE_DIM,)),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        # Output layer produces a flattened 28x28 image
        layers.Dense(784, activation='tanh')
    ], name='generator')
    return model
generator = build_generator()
generator.summary()
Key architectural choices:
- LeakyReLU activation: Allows small gradients for negative values, preventing “dying ReLU” problems common in GANs
- Batch normalization: Stabilizes training by normalizing layer inputs, crucial for GAN convergence
- Tanh output activation: Produces values in [-1, 1], matching our normalized input data range
- Progressive expansion: Each layer increases capacity, allowing the network to learn hierarchical features
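To make the first bullet concrete, here is the LeakyReLU function written out in NumPy (a sketch of the math, not the Keras layer itself): negative inputs are scaled by alpha rather than zeroed, so their gradient is alpha instead of 0.

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    # Pass positive values through; scale negatives by alpha instead of
    # clamping them to 0, so gradients never vanish entirely
    return np.where(x > 0, x, alpha * x)

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(leaky_relu(x))  # -0.4, -0.1, 0.0, 1.5
```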
Building the Discriminator Network
The discriminator is a binary classifier that determines whether an input image is real or generated. It should be powerful enough to provide useful gradients to the generator but not so powerful that it perfectly classifies everything, which would stop the generator from learning.
def build_discriminator():
    model = keras.Sequential([
        layers.Dense(512, input_shape=(784,)),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        # Output layer: probability that the input is real
        layers.Dense(1, activation='sigmoid')
    ], name='discriminator')
    return model
discriminator = build_discriminator()
discriminator.summary()
Important design considerations:
- Dropout layers: Prevent the discriminator from overfitting and memorizing training samples, which would cause mode collapse
- Sigmoid output: Produces a probability score between 0 (fake) and 1 (real)
- Simpler architecture than generator: The discriminator’s task is easier, so we use fewer parameters to maintain balance
- No batch normalization: In the discriminator, batch normalization can sometimes cause training instability
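To illustrate what the Dropout layers do at train time, here is a NumPy sketch of inverted dropout (the scheme Keras uses); the function name and seed here are illustrative, not part of the article's code:

```python
import numpy as np

def dropout(x, rate=0.3, rng=None):
    # Inverted dropout: zero out roughly `rate` of the units at train time
    # and rescale survivors so the expected activation is unchanged
    rng = rng if rng is not None else np.random.default_rng(0)
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)

x = np.ones(10)
print(dropout(x))  # ~70% of entries are 1/0.7 ≈ 1.43, the rest are 0
```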
Defining Loss Functions and Optimizers
GANs use binary cross-entropy loss for both networks, but with different targets. The discriminator maximizes its ability to classify real vs. fake, while the generator minimizes the discriminator’s ability to detect fakes.
# Binary cross-entropy loss
cross_entropy = keras.losses.BinaryCrossentropy()
def discriminator_loss(real_output, fake_output):
    # Discriminator should output 1 for real images
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    # Discriminator should output 0 for fake images
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    # Generator wants the discriminator to output 1 for fake images
    return cross_entropy(tf.ones_like(fake_output), fake_output)
# Separate optimizers for each network
generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
We use Adam optimizers with a lower learning rate (0.0002) and a reduced beta_1 (0.5) compared to the Keras defaults (0.001 and 0.9). These hyperparameters, popularized by the DCGAN paper, have been empirically shown to provide more stable GAN convergence.
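To see what these losses reward, here is binary cross-entropy written out in NumPy with the same targets — all-ones for outputs that should be called real, all-zeros for fakes. This is an illustration of the math, not a replacement for the Keras loss:

```python
import numpy as np

def bce(targets, probs, eps=1e-7):
    # Binary cross-entropy averaged over the batch, with clipping to
    # avoid log(0)
    probs = np.clip(probs, eps, 1 - eps)
    return -np.mean(targets * np.log(probs) + (1 - targets) * np.log(1 - probs))

# A discriminator that is confident and correct gets a low loss...
real_out = np.array([0.9, 0.95])  # scores on real images (target: 1)
fake_out = np.array([0.05, 0.1])  # scores on fake images (target: 0)
d_loss = bce(np.ones(2), real_out) + bce(np.zeros(2), fake_out)

# ...while the generator's loss is low only when fakes are scored as real
g_loss = bce(np.ones(2), fake_out)
print(d_loss, g_loss)  # d_loss is small here, g_loss is large
```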
Training Loop Implementation
The training loop alternates between updating the discriminator and generator. Each iteration involves: (1) training the discriminator on real and fake samples, (2) training the generator to fool the discriminator.
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate fake images
        generated_images = generator(noise, training=True)
        # Get discriminator outputs
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        # Calculate losses
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    # Calculate gradients (outside the tape contexts)
    gradients_of_generator = gen_tape.gradient(
        gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, discriminator.trainable_variables)
    # Apply gradients
    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss
def train(dataset, epochs):
    for epoch in range(epochs):
        gen_loss_avg = keras.metrics.Mean()
        disc_loss_avg = keras.metrics.Mean()
        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)
            gen_loss_avg.update_state(gen_loss)
            disc_loss_avg.update_state(disc_loss)
        # Print progress and save sample images every 5 epochs
        if (epoch + 1) % 5 == 0:
            print(f'Epoch {epoch + 1}, Gen Loss: {gen_loss_avg.result():.4f}, '
                  f'Disc Loss: {disc_loss_avg.result():.4f}')
            generate_and_save_images(generator, epoch + 1)

# Start training (define generate_and_save_images, shown in the next
# section, before running this line)
train(train_dataset, EPOCHS)
The @tf.function decorator compiles the training step into a TensorFlow graph, significantly speeding up execution. We use tf.GradientTape to record operations for automatic differentiation, computing gradients for both networks simultaneously.
Visualizing Results and Evaluation
Monitoring generated samples throughout training helps you assess GAN performance and detect issues like mode collapse (when the generator produces limited variety) or training divergence.
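One crude but cheap numerical signal of mode collapse is the spread of a generated batch: if the mean pairwise distance between samples keeps shrinking over epochs, the generator is likely collapsing onto a few modes. Here is a sketch of such a check, operating on any (batch, features) array of flattened samples; the function name and threshold usage are illustrative:

```python
import numpy as np

def batch_diversity(samples):
    # Mean pairwise Euclidean distance across a batch of flattened samples;
    # a value trending toward 0 during training suggests mode collapse
    diffs = samples[:, None, :] - samples[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    n = len(samples)
    return dists.sum() / (n * (n - 1))  # diagonal entries are 0

varied = np.random.default_rng(0).normal(size=(16, 784))
collapsed = np.tile(varied[:1], (16, 1))  # every sample identical
print(batch_diversity(varied) > batch_diversity(collapsed))  # True
```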
def generate_and_save_images(model, epoch, test_input=None):
    if test_input is None:
        test_input = tf.random.normal([16, NOISE_DIM])
    predictions = model(test_input, training=False)
    predictions = (predictions + 1) / 2.0  # Rescale from [-1, 1] to [0, 1]
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i].numpy().reshape(28, 28), cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f'generated_epoch_{epoch}.png')
    plt.close(fig)

# Generate final samples from a fixed noise batch
noise = tf.random.normal([16, NOISE_DIM])
generate_and_save_images(generator, 'final', noise)
For quantitative evaluation, consider implementing metrics like Inception Score or Fréchet Inception Distance (FID), though visual inspection often suffices for simple datasets like MNIST.
GANs are notoriously difficult to train, but following these practices—balanced architectures, proper normalization, separate optimizers, and careful monitoring—will help you achieve stable training and generate high-quality synthetic data. Experiment with different architectures, hyperparameters, and training techniques to improve your results.