# How to Implement ResNet in PyTorch
## Key Insights
- ResNet’s skip connections allow gradients to flow directly through the network, enabling training of models with 100+ layers without degradation
- The basic block (two 3x3 convs) works for shallower ResNets while bottleneck blocks (1x1, 3x3, 1x1) reduce parameters in deeper variants
- PyTorch’s modular design makes ResNet implementation straightforward—build blocks first, then stack them with dimension-matching shortcuts
## Understanding ResNet’s Core Innovation
Deep neural networks should theoretically perform better as you add layers—more capacity means more representational power. In practice, networks deeper than 20-30 layers often performed worse than shallower counterparts. This wasn’t overfitting; even training accuracy degraded.
ResNet solved this with residual learning. Instead of learning a direct mapping H(x), each layer learns a residual function F(x) = H(x) - x. The output becomes F(x) + x, where the “+x” is a skip connection that bypasses the layer entirely.
This simple addition has profound effects. Gradients can flow directly backward through skip connections, preventing vanishing gradients. If a layer isn’t useful, the network can learn F(x) ≈ 0, effectively removing that layer. This makes optimization easier and enables networks with 152+ layers.
The mathematical formulation is straightforward:

```python
# Conceptual representation
output = F(x, weights) + x
```

Where F represents the stacked convolutional layers and x is the identity mapping from the input.
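A one-line sketch of the backward pass makes the benefit concrete: for y = F(x) + x, the chain rule gives

```latex
\frac{\partial \mathcal{L}}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial y}\,\frac{\partial y}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial y}\left(I + \frac{\partial F}{\partial x}\right)
```

The identity term I means the upstream gradient reaches x undiminished even when ∂F/∂x is small, which is exactly why very deep stacks remain trainable.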
## Building Basic and Bottleneck Blocks
ResNet uses two types of residual blocks depending on network depth. Let’s implement both.
The BasicBlock uses two 3x3 convolutions and is used in ResNet-18 and ResNet-34:
```python
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1  # Output channels multiplier

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # First 3x3 convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second 3x3 convolution
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        # Main path
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Skip connection
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
```
The BottleneckBlock reduces computational cost for deeper networks (ResNet-50/101/152) using 1x1 convolutions to compress and expand channels:
```python
class BottleneckBlock(nn.Module):
    expansion = 4  # Output channels = out_channels * 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BottleneckBlock, self).__init__()
        # 1x1 convolution for dimension reduction
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # 3x3 convolution (bottleneck)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 convolution for dimension expansion
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
```
The downsample parameter handles dimension mismatches. When stride > 1 or channels change, we need a projection shortcut—typically a 1x1 convolution.
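A standalone sketch shows what that projection does; the channel and spatial sizes here are illustrative, chosen to mimic a stage transition that doubles channels and halves resolution:

```python
import torch
import torch.nn as nn

# Projection shortcut for an illustrative stage transition:
# stride 2 halves the resolution, 64 -> 128 doubles the channels.
downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)

x = torch.randn(2, 64, 32, 32)
identity = downsample(x)
print(identity.shape)  # now matches the main path: (2, 128, 16, 16)
```

Without this reshaping, the `out += identity` addition in the block's forward pass would fail on a shape mismatch.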
## Assembling the Complete ResNet Architecture
Now we stack blocks into the full ResNet model. The architecture follows this pattern:
- Initial 7x7 convolution and max pooling
- Four stages of residual blocks with increasing channels (64, 128, 256, 512)
- Global average pooling and fully connected classifier
```python
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channels = 64

        # Initial convolution
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride):
        downsample = None
        # Create projection shortcut if needed
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion)
            )

        layers = []
        # First block may need downsampling
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        # Remaining blocks
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```
```python
# Factory functions for different ResNet variants
def resnet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

def resnet50(num_classes=1000):
    return ResNet(BottleneckBlock, [3, 4, 6, 3], num_classes)

def resnet101(num_classes=1000):
    return ResNet(BottleneckBlock, [3, 4, 23, 3], num_classes)
```
The layers parameter specifies blocks per stage. ResNet-50 uses [3, 4, 6, 3]: 16 bottleneck blocks with 3 convolutions each make 48 layers, and the initial 7x7 convolution plus the final fully connected layer bring the total to 50 weighted layers.
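That count can be verified with quick arithmetic (pooling layers carry no weights and are excluded):

```python
# Each bottleneck block contributes 3 convolutions (1x1, 3x3, 1x1).
blocks_per_stage = [3, 4, 6, 3]          # ResNet-50 configuration
convs_in_blocks = 3 * sum(blocks_per_stage)  # 48

# Add the initial 7x7 convolution and the final fully connected layer.
total_weighted_layers = 1 + convs_in_blocks + 1
print(total_weighted_layers)  # 50
```

The same arithmetic gives 101 for [3, 4, 23, 3] and 152 for [3, 8, 36, 3].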
## Training ResNet from Scratch
Here’s a practical training setup for CIFAR-10:
```python
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Data augmentation and normalization
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load CIFAR-10
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=4)

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = resnet18(num_classes=10).to(device)

# Optimizer and scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Training loop
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    return running_loss / len(loader), 100. * correct / total

# Training
for epoch in range(200):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    scheduler.step()
    if epoch % 10 == 0:
        print(f'Epoch {epoch}: Loss={train_loss:.3f}, Acc={train_acc:.2f}%')
```
Use cosine annealing for smooth learning rate decay. For ImageNet-scale training, start with lr=0.1 and use step decay (divide by 10 at epochs 30, 60, 90).
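The step-decay schedule can be expressed with MultiStepLR; the tiny linear model below is a stand-in for a ResNet so the sketch runs on its own:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)  # stand-in for a ResNet, just to build an optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
# Divide the learning rate by 10 at epochs 30, 60, and 90
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)

for epoch in range(100):
    optimizer.step()   # the real training step would go here
    scheduler.step()

# After epoch 90 the rate has been cut three times: 0.1 * 0.1**3 = 1e-4
print(optimizer.param_groups[0]['lr'])
```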
## Transfer Learning and Fine-Tuning
Pre-trained ResNets excel at transfer learning. Load ImageNet weights and adapt to your task:
```python
from torchvision.models import resnet50, ResNet50_Weights

# Load pretrained model
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Freeze all layers except the final classifier
for param in model.parameters():
    param.requires_grad = False

# Replace classifier for your dataset (e.g., 10 classes)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

# Only train the new classifier
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# For fine-tuning all layers after initial training:
# unfreeze and use a smaller learning rate
for param in model.parameters():
    param.requires_grad = True
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```
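A common refinement when unfreezing is to give the pretrained backbone a smaller learning rate than the new head via optimizer parameter groups. A minimal sketch, using a hypothetical two-layer stand-in rather than the torchvision model above:

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in: one "backbone" layer and one "head" layer.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
backbone, head = model[0], model[1]

# Pretrained layers get a smaller rate than the freshly initialized classifier.
optimizer = optim.SGD([
    {'params': backbone.parameters(), 'lr': 1e-4},
    {'params': head.parameters(), 'lr': 1e-3},
], momentum=0.9)

print([g['lr'] for g in optimizer.param_groups])  # [0.0001, 0.001]
```

For a real ResNet, the backbone group would hold every parameter except model.fc.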
Evaluation is straightforward:
```python
def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / total

test_acc = evaluate(model, test_loader, device)
print(f'Test Accuracy: {test_acc:.2f}%')
```
ResNet’s architecture is elegant and effective. The skip connections solve a fundamental optimization problem, making deep networks trainable. Start with pretrained models for most tasks—training from scratch requires significant compute and data. When you do train from scratch, use proper data augmentation, learning rate schedules, and give it enough epochs. ResNet-50 typically achieves 76%+ top-1 accuracy on ImageNet with proper training.