How to Implement a GAN in PyTorch
Key Insights
- GANs pit two neural networks against each other—a generator creates fake data while a discriminator learns to distinguish real from fake, forcing both to improve through adversarial training
- The training process alternates between updating the discriminator (teaching it to identify fakes) and the generator (teaching it to fool the discriminator), requiring careful balance to prevent one network from dominating
- Mode collapse and training instability are common issues that require techniques like label smoothing, proper learning rates, and monitoring loss curves to ensure both networks improve together
Introduction to GANs
Generative Adversarial Networks (GANs) represent one of the most exciting developments in deep learning. Introduced by Ian Goodfellow in 2014, GANs use a game-theoretic approach where two neural networks compete against each other. The generator creates synthetic data samples, while the discriminator attempts to distinguish between real and generated samples. This adversarial process pushes both networks to improve continuously.
The architecture is deceptively simple: the generator takes random noise as input and transforms it into data that resembles your training set (images, for example). The discriminator receives both real samples from your dataset and fake samples from the generator, outputting a probability score indicating whether each sample is real or fake.
GANs have proven invaluable for image generation: creating photorealistic faces, augmenting limited datasets, performing style transfer, and even generating synthetic training data for other machine learning models. Let's build one from scratch in PyTorch.
Setting Up the Environment
First, install the required dependencies. You’ll need PyTorch, torchvision for dataset utilities, and matplotlib for visualization.
```bash
pip install torch torchvision matplotlib
```
Here are the necessary imports for our implementation:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 100
hidden_dim = 256
image_dim = 28 * 28  # MNIST images are 28x28
batch_size = 64
learning_rate = 0.0002
num_epochs = 50
```
We’ll use the MNIST dataset because it’s simple enough to train quickly while demonstrating GAN fundamentals:
```python
# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

mnist_data = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)

dataloader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True)
```
Building the Generator Network
The generator takes a random noise vector (latent vector) and transforms it into a realistic-looking image. For MNIST, we’ll use a simple fully-connected architecture:
```python
class Generator(nn.Module):
    def __init__(self, latent_dim, hidden_dim, image_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 4, image_dim),
            nn.Tanh()  # Output in range [-1, 1] to match normalized data
        )

    def forward(self, z):
        return self.model(z)

# Initialize generator
generator = Generator(latent_dim, hidden_dim, image_dim).to(device)
```
The Tanh activation in the final layer is crucial—it ensures outputs match our normalized image range of [-1, 1]. LeakyReLU helps prevent dying neurons during training.
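These two properties are easy to verify directly. The snippet below is a standalone illustration (not part of the model above) showing how each activation treats a range of inputs:

```python
import torch
import torch.nn as nn

# Tanh squashes any input into [-1, 1], matching the normalized pixel range;
# LeakyReLU(0.2) passes positives through unchanged and scales negatives by
# 0.2 instead of zeroing them, so gradients keep flowing.
x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
print(nn.Tanh()(x))          # every value lies strictly within [-1, 1]
print(nn.LeakyReLU(0.2)(x))  # values: -2.0, -0.2, 0.0, 1.0, 10.0
```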
Building the Discriminator Network
The discriminator is a binary classifier that outputs the probability that an input image is real:
```python
class Discriminator(nn.Module):
    def __init__(self, image_dim, hidden_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_dim, hidden_dim * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # Output probability [0, 1]
        )

    def forward(self, img):
        return self.model(img)

# Initialize discriminator
discriminator = Discriminator(image_dim, hidden_dim).to(device)
```
Dropout layers help prevent overfitting and, just as importantly, keep the discriminator from becoming too powerful too quickly.
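One detail worth knowing: dropout is only active in training mode, which is why the visualization helper later switches the generator between eval() and train(). A quick standalone check:

```python
import torch
import torch.nn as nn

# Dropout zeroes ~30% of values (and rescales the rest) in training mode,
# but becomes a no-op once .eval() is called.
drop = nn.Dropout(0.3)
x = torch.ones(1000)

drop.train()
print(torch.equal(drop(x), x))  # False: some values zeroed, rest scaled by 1/0.7

drop.eval()
print(torch.equal(drop(x), x))  # True: dropout does nothing in eval mode
```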
Training Loop Implementation
The training process alternates between updating the discriminator and generator. The discriminator learns to classify real vs. fake images, while the generator learns to create images that fool the discriminator:
```python
# Loss and optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Training loop
for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        batch_size = real_images.size(0)
        real_images = real_images.view(batch_size, -1).to(device)

        # Labels for real and fake images
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # =================== Train Discriminator ===================
        # Train on real images
        outputs = discriminator(real_images)
        d_loss_real = criterion(outputs, real_labels)

        # Train on fake images
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_images = generator(z)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        # Combined discriminator loss
        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # =================== Train Generator ===================
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_images = generator(z)
        outputs = discriminator(fake_images)

        # Generator tries to fool discriminator
        g_loss = criterion(outputs, real_labels)
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    # Print progress once per epoch
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
```
Notice the detach() call when training the discriminator on fake images—this prevents gradients from flowing back to the generator during discriminator training.
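You can verify this behavior with stand-in networks (single Linear layers here, not the models above): after backpropagating a discriminator-style loss through a detached tensor, the "generator" has no gradients at all.

```python
import torch
import torch.nn as nn

# Gradients stop at the detached tensor: the generator's weights receive
# nothing from the discriminator's loss, while the discriminator still trains.
gen = nn.Linear(4, 4)
disc = nn.Linear(4, 1)

fake = gen(torch.randn(2, 4))
disc(fake.detach()).sum().backward()

print(gen.weight.grad)               # None: nothing flowed past detach()
print(disc.weight.grad is not None)  # True: the discriminator still gets gradients
```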
Visualizing Results
Monitoring generated images throughout training helps you understand whether your GAN is learning properly:
```python
def generate_and_save_images(generator, epoch, latent_dim, num_images=16):
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim).to(device)
        generated_images = generator(z).cpu().view(-1, 28, 28)

    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for i, ax in enumerate(axes.flat):
        ax.imshow(generated_images[i], cmap='gray')
        ax.axis('off')
    plt.suptitle(f'Generated Images - Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'gan_epoch_{epoch}.png')
    plt.close()
    generator.train()

# Call this inside the epoch loop during training:
if (epoch + 1) % 10 == 0:
    generate_and_save_images(generator, epoch + 1, latent_dim)
```
Common Pitfalls and Best Practices
Mode collapse occurs when the generator produces limited variety, creating the same few outputs regardless of input noise. Combat this by monitoring sample diversity and using techniques like mini-batch discrimination or adding noise to discriminator inputs.
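The noise trick can be sketched in a few lines. The helper below is a hypothetical addition (not part of the code above) that adds Gaussian "instance noise" to discriminator inputs and decays it over training:

```python
import torch

# Hypothetical helper: add Gaussian noise to discriminator inputs (real and
# fake alike), decaying linearly to zero over training. The early overlap it
# creates between the two distributions gives the generator a usable signal.
def add_instance_noise(images, epoch, num_epochs, initial_std=0.1):
    std = initial_std * (1 - epoch / num_epochs)  # linear decay to 0
    return images + torch.randn_like(images) * std
```

You would call this on both `real_images` and `fake_images` just before each discriminator forward pass; the `initial_std` of 0.1 is an illustrative starting point.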
Training instability is the most common GAN challenge. If the discriminator becomes too strong, it provides no useful gradient signal to the generator. If the generator dominates, the discriminator learns nothing. Balance is key.
Use label smoothing by replacing real labels (1.0) with values like 0.9, and fake labels (0.0) with 0.1. This prevents the discriminator from becoming overconfident:
```python
real_labels = torch.ones(batch_size, 1).to(device) * 0.9
fake_labels = torch.zeros(batch_size, 1).to(device) + 0.1
```
Learning rate matters enormously. The value 0.0002 with beta1=0.5 for Adam has become a standard starting point for GANs. Adjust carefully if changing.
Monitor both losses. Neither should converge to zero. Healthy GAN training shows both losses fluctuating around stable values. If discriminator loss approaches zero, it’s winning too easily—consider reducing its learning rate or capacity.
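A lightweight way to act on this advice is to keep running loss histories and flag a collapsing discriminator automatically. This monitor is a sketch; the window size and 0.1 floor are illustrative choices, not canonical values:

```python
# Sketch of a loss monitor for GAN training.
class LossMonitor:
    def __init__(self, window=100, d_floor=0.1):
        self.d_losses, self.g_losses = [], []
        self.window = window
        self.d_floor = d_floor

    def update(self, d_loss, g_loss):
        self.d_losses.append(d_loss)
        self.g_losses.append(g_loss)

    def discriminator_collapsed(self):
        # True when the recent average discriminator loss sits near zero,
        # i.e. the discriminator is winning too easily.
        recent = self.d_losses[-self.window:]
        return len(recent) == self.window and sum(recent) / self.window < self.d_floor
```

Call `monitor.update(d_loss.item(), g_loss.item())` each batch and check `discriminator_collapsed()` once per epoch before deciding whether to intervene.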
Finally, save checkpoints regularly. GAN training is unstable enough that your best results might occur mid-training rather than at the end.
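A simple checkpointing helper might look like this (the file naming is illustrative). Saving both networks and both optimizer states lets you resume adversarial training, or roll back to whichever epoch produced the best samples:

```python
import torch

# Save everything needed to resume GAN training from a given epoch.
def save_checkpoint(generator, discriminator, optimizer_G, optimizer_D, epoch, path):
    torch.save({
        'epoch': epoch,
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'optimizer_G': optimizer_G.state_dict(),
        'optimizer_D': optimizer_D.state_dict(),
    }, path)
```

Calling it every 10 epochs, e.g. `save_checkpoint(generator, discriminator, optimizer_G, optimizer_D, epoch, f'gan_checkpoint_{epoch}.pth')`, and restoring with `torch.load` plus `load_state_dict` is usually enough.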
This implementation provides a solid foundation for understanding GANs. For production use, consider more sophisticated architectures like DCGAN (convolutional layers), progressive growing, or Wasserstein GAN with gradient penalty for improved stability.