How to Implement VGG in PyTorch


Key Insights

  • VGG’s uniform architecture using only 3x3 convolutions makes it straightforward to implement and understand, serving as an excellent learning model for deep CNNs
  • Building VGG from scratch requires creating reusable convolutional blocks that stack multiple conv layers before pooling, following a consistent pattern throughout the network
  • For production use, leverage pretrained VGG models from torchvision and apply transfer learning rather than training from scratch, saving weeks of computation time

Understanding VGG Architecture

VGG (Visual Geometry Group) revolutionized deep learning in 2014 by demonstrating that network depth significantly impacts performance. The architecture’s elegance lies in its simplicity: stack small 3x3 convolutional filters repeatedly, interspersed with max pooling layers. VGG-16 contains 16 weight layers (13 convolutional, 3 fully connected), while VGG-19 extends this to 19 layers.

The key insight of VGG was that two 3x3 conv layers have the same effective receptive field as one 5x5 layer, but with fewer parameters and more non-linearity. This design choice makes VGG both powerful and conceptually clean—perfect for understanding how modern CNNs work.
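The arithmetic behind that claim is easy to check. Ignoring biases, a 3x3 conv that maps C channels to C channels has 9C² weights, so two stacked layers cost 18C², compared with 25C² for one 5x5 layer. The same reasoning extends to three 3x3 layers versus one 7x7 layer. Here is a small pure-Python sketch; the helper names are illustrative, and C = 256 is an arbitrary example width:

```python
# Compare a stack of n 3x3 conv layers against one (2n+1)x(2n+1) conv layer,
# both mapping C channels to C channels (biases ignored for simplicity).

def stacked_3x3_receptive_field(n):
    # Each stride-1 3x3 conv grows the receptive field by 2 pixels: 1 -> 3 -> 5 -> 7 ...
    return 1 + 2 * n

def stacked_3x3_params(n, channels):
    # n layers, each with 3*3*C*C weights
    return n * 3 * 3 * channels * channels

def single_conv_params(kernel, channels):
    # One layer with kernel*kernel*C*C weights
    return kernel * kernel * channels * channels

C = 256  # example channel width
for n in (2, 3):
    k = stacked_3x3_receptive_field(n)
    stacked = stacked_3x3_params(n, C)
    single = single_conv_params(k, C)
    print(f"{n} x 3x3 -> {k}x{k} field: {stacked:,} vs {single:,} weights "
          f"({stacked / single:.0%} of the single layer)")
```

Each extra 3x3 layer also adds another ReLU, which is where the additional non-linearity comes from.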

VGG Building Blocks

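VGG follows a consistent pattern: convolutional blocks, each followed by max pooling. In VGG-16 and VGG-19, each block contains two to four convolutional layers with 3x3 kernels, ReLU activations, and same padding (padding=1, so height and width are preserved). After each block, a 2x2 max pooling layer with stride 2 halves the spatial dimensions. Meanwhile, the number of channels doubles from block to block until it reaches 512.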

Here’s the VGG-16 configuration:

# VGG-16 layer configuration
# Format: (num_filters, num_conv_layers_in_block)
VGG16_CONFIG = [
    (64, 2),   # Block 1: 2 conv layers, 64 filters
    (128, 2),  # Block 2: 2 conv layers, 128 filters
    (256, 3),  # Block 3: 3 conv layers, 256 filters
    (512, 3),  # Block 4: 3 conv layers, 512 filters
    (512, 3),  # Block 5: 3 conv layers, 512 filters
]
# Followed by 3 fully connected layers: 4096 -> 4096 -> 1000

The spatial dimension progression for 224x224 input: 224 → 112 → 56 → 28 → 14 → 7, while channels go: 3 → 64 → 128 → 256 → 512 → 512.
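You can derive the same progression from VGG16_CONFIG. A 3x3 conv with padding=1 keeps the height and width, and each 2x2, stride-2 max pool halves them. The pure-Python sketch below uses trace_shapes as an illustrative helper name and repeats the config so the snippet runs on its own:

```python
# VGG-16 block configuration: (num_filters, num_conv_layers_in_block)
VGG16_CONFIG = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]

def trace_shapes(config, size=224, in_channels=3):
    """Return (channels, height/width) at the input and after each block.

    3x3 convs with padding=1 keep the spatial size; the 2x2, stride-2
    max pool at the end of each block halves it (integer division).
    """
    shapes = [(in_channels, size)]
    for filters, _num_convs in config:
        size //= 2
        shapes.append((filters, size))
    return shapes

for channels, size in trace_shapes(VGG16_CONFIG):
    print(f"{channels:>3} channels @ {size}x{size}")

# Total conv layers: 2 + 2 + 3 + 3 + 3 = 13
print(sum(n for _, n in VGG16_CONFIG))
```

Calling trace_shapes(VGG16_CONFIG, size=32) shows that a 32x32 CIFAR-10 image shrinks to 1x1 after the five pooling layers. This matters if you train on CIFAR-10 without resizing.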

Implementing the Convolutional Block

The core building block is a function that creates multiple conv layers followed by max pooling. This reusable component eliminates code duplication and makes the architecture crystal clear.

import torch
import torch.nn as nn

def make_vgg_block(in_channels, out_channels, num_convs):
    """
    Create a VGG convolutional block.
    
    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels (filters)
        num_convs: Number of conv layers in this block
    
    Returns:
        nn.Sequential containing conv layers and max pooling
    """
    layers = []
    
    for i in range(num_convs):
        layers.append(nn.Conv2d(
            in_channels if i == 0 else out_channels,
            out_channels,
            kernel_size=3,
            padding=1  # Same padding to preserve spatial dimensions
        ))
        layers.append(nn.ReLU(inplace=True))
    
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    
    return nn.Sequential(*layers)

# Example usage
block1 = make_vgg_block(3, 64, 2)    # First block: RGB -> 64 channels
block2 = make_vgg_block(64, 128, 2)  # Second block: 64 -> 128 channels

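The first conv layer in each block maps in_channels to out_channels, and every later conv maps out_channels to out_channels, so only the first layer needs special handling. Setting inplace=True lets each ReLU overwrite its input tensor instead of allocating a new one, which saves memory.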
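The block function also makes it easy to build the whole feature extractor directly from VGG16_CONFIG instead of listing each block by hand. The sketch below repeats make_vgg_block and VGG16_CONFIG so it runs on its own. make_vgg_features is a hypothetical helper name, not a torchvision function. Pushing a dummy 224x224 image through it should reproduce the shape progression described earlier:

```python
import torch
import torch.nn as nn

VGG16_CONFIG = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]

def make_vgg_block(in_channels, out_channels, num_convs):
    # Same block as above: num_convs x (3x3 conv + ReLU), then 2x2 max pool
    layers = []
    for i in range(num_convs):
        layers.append(nn.Conv2d(in_channels if i == 0 else out_channels,
                                out_channels, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

def make_vgg_features(config, in_channels=3):
    # Hypothetical helper: chain the blocks so each block's output
    # channel count becomes the next block's input channel count
    blocks = []
    for out_channels, num_convs in config:
        blocks.append(make_vgg_block(in_channels, out_channels, num_convs))
        in_channels = out_channels
    return nn.Sequential(*blocks)

features = make_vgg_features(VGG16_CONFIG)
x = torch.randn(1, 3, 224, 224)  # dummy batch of one RGB image
with torch.no_grad():
    for i, block in enumerate(features, start=1):
        x = block(x)
        print(f"after block {i}: {tuple(x.shape)}")
```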

Building the Complete VGG-16 Model

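Now we assemble the full network by stacking blocks and adding fully connected layers. The layer layout follows VGG-16 from the original paper. The adaptive average pooling layer and the Kaiming (He) weight initialization follow torchvision's VGG implementation rather than the paper.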

class VGG16(nn.Module):
    def __init__(self, num_classes=1000):
        super(VGG16, self).__init__()
        
        # Feature extraction layers
        self.features = nn.Sequential(
            # Block 1: 64 filters, 2 conv layers
            *self._make_block(3, 64, 2),
            # Block 2: 128 filters, 2 conv layers
            *self._make_block(64, 128, 2),
            # Block 3: 256 filters, 3 conv layers
            *self._make_block(128, 256, 3),
            # Block 4: 512 filters, 3 conv layers
            *self._make_block(256, 512, 3),
            # Block 5: 512 filters, 3 conv layers
            *self._make_block(512, 512, 3),
        )
        
        # Adaptive pooling to handle different input sizes
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        
        # Fully connected classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )
        
        # Initialize weights
        self._initialize_weights()
    
    def _make_block(self, in_channels, out_channels, num_convs):
        """Helper method to create a VGG block."""
        layers = []
        for i in range(num_convs):
            layers.append(nn.Conv2d(
                in_channels if i == 0 else out_channels,
                out_channels,
                kernel_size=3,
                padding=1
            ))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return layers
    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
    
    def _initialize_weights(self):
        """Initialize weights using He initialization."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

# Create model instance
model = VGG16(num_classes=10)  # For CIFAR-10
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
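You can cross-check the printed total by hand. Each 3x3 conv has 3*3*in*out weights plus one bias per filter, and each Linear layer has in*out weights plus out biases. The following pure-Python sketch (vgg16_param_count is a hypothetical helper) should match the PyTorch count for VGG16(num_classes=10):

```python
# Repeated from above so this snippet runs on its own
VGG16_CONFIG = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]

def conv_params(in_ch, out_ch, k=3):
    return k * k * in_ch * out_ch + out_ch  # weights + one bias per filter

def linear_params(in_features, out_features):
    return in_features * out_features + out_features  # weights + biases

def vgg16_param_count(num_classes=1000):
    """Return (conv_params, total_params) for the VGG16 class defined above."""
    conv_total, in_ch = 0, 3
    for out_ch, num_convs in VGG16_CONFIG:
        for i in range(num_convs):
            conv_total += conv_params(in_ch if i == 0 else out_ch, out_ch)
        in_ch = out_ch
    # Classifier: 512*7*7 -> 4096 -> 4096 -> num_classes
    fc_total = (linear_params(512 * 7 * 7, 4096)
                + linear_params(4096, 4096)
                + linear_params(4096, num_classes))
    return conv_total, conv_total + fc_total

conv_total, total = vgg16_param_count(num_classes=10)
print(f"conv layers: {conv_total:,}")   # conv layers: 14,714,688
print(f"whole model: {total:,}")        # whole model: 134,301,514
print(f"classifier share: {1 - conv_total / total:.0%}")  # classifier share: 89%
print(f"ImageNet head: {vgg16_param_count()[1]:,}")  # ImageNet head: 138,357,544
```

Most of the weights live in the fully connected classifier, not in the convolutions.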

Training VGG on CIFAR-10

Here’s a practical training example using CIFAR-10, which is more accessible than ImageNet for learning purposes.

import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Data preprocessing
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(224),  # Upsample 32x32 CIFAR images to VGG's native 224x224 input size
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load CIFAR-10
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

# Setup model, loss, optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VGG16(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
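# Decay the LR by 10x every 30 epochs; with the 10-epoch loop below the decay never fires,
# so train longer (or lower step_size) for it to take effect
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)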

# Training loop
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item():.3f}, Acc: {100.*correct/total:.2f}%')
    
    return running_loss / len(loader), 100. * correct / total

# Train for multiple epochs
for epoch in range(10):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    scheduler.step()
    print(f'Epoch {epoch}: Loss={train_loss:.3f}, Acc={train_acc:.2f}%')
    
    # Save checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
    }, f'vgg16_checkpoint_epoch_{epoch}.pth')
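The loop above only reports training metrics. Below is a minimal evaluation sketch. It assumes a test_loader built like train_loader, but from datasets.CIFAR10(..., train=False) and without the random crop and flip. evaluate and test_loader are hypothetical names, not part of the original code:

```python
import torch
import torch.nn as nn

def evaluate(model, loader, criterion, device):
    """Compute average loss and accuracy (%) without updating any weights."""
    model.eval()  # disables dropout
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():  # no autograd graph is built, which saves memory
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # criterion returns the batch mean, so weight it by batch size
            total_loss += criterion(outputs, targets).item() * targets.size(0)
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            total += targets.size(0)
    return total_loss / total, 100.0 * correct / total
```

Call it once per epoch, for example val_loss, val_acc = evaluate(model, test_loader, criterion, device). Since train_epoch calls model.train() at its start, alternating the two functions is safe.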

Transfer Learning with Pretrained VGG

Training VGG from scratch is expensive: the original authors report two to three weeks of ImageNet training per network on four GPUs. Instead, start from pretrained weights and fine-tune them for your specific task.

from torchvision import models

# Load VGG16 with ImageNet weights (torchvision >= 0.13; older releases use pretrained=True)
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# Freeze all feature layers
for param in model.features.parameters():
    param.requires_grad = False

# Replace classifier for 10 classes (CIFAR-10)
num_features = model.classifier[0].in_features
model.classifier = nn.Sequential(
    nn.Linear(num_features, 4096),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(4096, 1024),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(1024, 10),
)

# Move to device
model = model.to(device)

# Only train classifier parameters
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

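# Now train with train_epoch() as before; only the classifier is updated, so convergence is much faster.
# Normalize inputs with ImageNet statistics, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
# because those are the statistics the pretrained weights were trained with.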

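This approach trains only the new classifier while reusing the convolutional features learned on ImageNet (millions of images). The classifier is not small, though. Its first Linear layer alone maps 25,088 features to 4,096 units, so the trainable part has roughly 107 million parameters, compared with about 14.7 million frozen ones in model.features. Because no gradients flow into the frozen convolutional layers, each step is cheaper, and good results typically arrive within a few epochs.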

Performance Optimization Tips

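Memory Management: VGG is memory-hungry. If training runs out of memory, reduce the batch size or use gradient checkpointing (torch.utils.checkpoint) to trade extra compute for lower activation memory. Mixed precision training uses torch.autocast with a gradient scaler (torch.amp in recent PyTorch releases, torch.cuda.amp in older ones). It runs most convolutions and matrix multiplications in float16, which substantially reduces activation memory and typically speeds up training on modern NVIDIA GPUs.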

Batch Normalization: The original VGG doesn't use batch norm. Inserting nn.BatchNorm2d(out_channels) between each nn.Conv2d and its nn.ReLU(inplace=True) usually makes training noticeably more stable and faster to converge. torchvision provides this variant as models.vgg16_bn.
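Here is a sketch of the block function with batch norm added. make_vgg_bn_block is a hypothetical name; torchvision's vgg16_bn uses the same Conv2d, BatchNorm2d, ReLU ordering:

```python
import torch
import torch.nn as nn

def make_vgg_bn_block(in_channels, out_channels, num_convs):
    """VGG block with BatchNorm2d between each conv and its ReLU."""
    layers = []
    for i in range(num_convs):
        layers += [
            nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),  # normalizes each channel over the batch
            nn.ReLU(inplace=True),
        ]
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

block = make_vgg_bn_block(3, 64, 2)
out = block(torch.randn(4, 3, 32, 32))  # batch of 4 small RGB images
print(out.shape)  # torch.Size([4, 64, 16, 16])
```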

Modern Alternatives: While VGG is excellent for learning, ResNet, EfficientNet, or Vision Transformers achieve better accuracy with fewer parameters. Use VGG when you need a simple, interpretable baseline or when working with pretrained features.

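Input Size: VGG was designed for 224x224 images. Thanks to its AdaptiveAvgPool2d layer, the VGG16 class above still accepts 32x32 CIFAR-10 images. However, five poolings shrink them to 1x1, and the adaptive pool simply copies that single position to fill the 7x7 grid, so the 25,088 classifier inputs are 49 copies of the same 512 values. For small images, either resize them to 224x224 as in the training example (49 times more pixels per image, with correspondingly more computation) or modify the architecture to use fewer pooling layers.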
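Below is one way to follow the "fewer pooling layers" route. The choices are assumptions for illustration, not a standard recipe. VGG16Small is a hypothetical name. The fifth max pool is dropped, so a 32x32 input ends at 512x2x2, and the classifier shrinks to a single 512-unit hidden layer:

```python
import torch
import torch.nn as nn

VGG16_CONFIG = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]

class VGG16Small(nn.Module):
    """Hypothetical VGG-16 variant for 32x32 inputs: no pooling after the last block."""

    def __init__(self, num_classes=10):
        super().__init__()
        layers, in_ch = [], 3
        for block_idx, (out_ch, num_convs) in enumerate(VGG16_CONFIG):
            for i in range(num_convs):
                layers += [nn.Conv2d(in_ch if i == 0 else out_ch, out_ch, 3, padding=1),
                           nn.ReLU(inplace=True)]
            if block_idx < len(VGG16_CONFIG) - 1:  # skip the 5th max pool
                layers.append(nn.MaxPool2d(2, 2))
            in_ch = out_ch
        self.features = nn.Sequential(*layers)
        # 32 -> 16 -> 8 -> 4 -> 2 after four pools, so features are 512 x 2 x 2
        self.classifier = nn.Sequential(
            nn.Linear(512 * 2 * 2, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))

model = VGG16Small()
print(model(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
```

With this variant, the Resize(224) step can be dropped from the CIFAR-10 transform.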

The beauty of VGG lies in its simplicity. Understanding this architecture provides a solid foundation for exploring more complex CNNs. Start here, master the concepts, then move to modern architectures.
