How to Implement U-Net in PyTorch
Key Insights
- U-Net’s skip connections preserve spatial information lost during downsampling, giving it an advantage over plain encoder-decoder architectures for dense prediction tasks that need precise localization
- The architecture is modular—building reusable double convolution, downsampling, and upsampling blocks makes the implementation clean and maintainable
- PyTorch’s torch.cat along the channel dimension enables skip connections, while ConvTranspose2d provides learnable upsampling as an alternative to fixed bilinear interpolation
Introduction to U-Net Architecture
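U-Net emerged from a 2015 paper by Ronneberger et al. for biomedical image segmentation, where pixel-perfect predictions matter. Unlike classification networks that output a single label, U-Net produces a segmentation map in which every pixel gets classified. The original paper used unpadded convolutions, so its output map was somewhat smaller than its input; with padded convolutions, as in the implementation below, the output matches the input's height and width.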
The architecture follows an encoder-decoder pattern but with a critical difference: skip connections that concatenate feature maps from the encoder directly to corresponding decoder layers. This design solves a fundamental problem in semantic segmentation—downsampling loses spatial information that upsampling alone cannot recover. Skip connections provide the high-resolution details needed for precise localization.
U-Net excels when training data is limited. Combined with heavy data augmentation, its design lets it learn from small datasets, making it a strong fit for medical imaging, where labeled data is expensive to obtain.
Understanding the U-Net Components
U-Net has three fundamental components working in concert:
The contracting path (encoder) captures context through successive convolutions and max pooling operations. Each downsampling step doubles the feature channels while halving spatial dimensions, building increasingly abstract representations.
The expansive path (decoder) enables precise localization through upsampling and convolutions. Each upsampling step halves feature channels while doubling spatial dimensions, reconstructing the output segmentation map.
Skip connections bridge the encoder and decoder at each resolution level. After each upsampling operation, the decoder concatenates the upsampled features with the feature maps from the corresponding encoder layer. This provides the fine-grained spatial information lost during downsampling.
The data flow follows this pattern: input → encode → bottleneck → decode with skip connections → output segmentation map.
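To make the data flow concrete, here is a small pure-Python sketch that traces (channels, height, width) through each level. It assumes padded 3×3 convolutions, four 2×2 poolings, a base width of 64, and an input whose sides are divisible by 16; `trace_unet_shapes` is a hypothetical helper, not part of PyTorch.

```python
def trace_unet_shapes(height, width, base_channels=64, depth=4):
    """Return (name, channels, height, width) for each U-Net level.

    Assumes padded 3x3 convs (spatial size unchanged by convs), 2x2 max
    pooling that halves H and W while doubling channels, and x2 upsampling
    that doubles H and W while halving channels.
    """
    shapes = []
    c, h, w = base_channels, height, width
    shapes.append(("enc0", c, h, w))
    # Contracting path: the last step produces the bottleneck
    for level in range(1, depth + 1):
        c, h, w = c * 2, h // 2, w // 2
        name = "bottleneck" if level == depth else f"enc{level}"
        shapes.append((name, c, h, w))
    # Expansive path: mirror the encoder back to full resolution
    for level in range(depth - 1, -1, -1):
        c, h, w = c // 2, h * 2, w * 2
        shapes.append((f"dec{level}", c, h, w))
    return shapes


for name, c, h, w in trace_unet_shapes(256, 256):
    print(f"{name:>10}: {c:4d} x {h} x {w}")
```

For a 256×256 input this gives a 1024-channel bottleneck at 16×16 and a 64-channel map back at 256×256, which the final 1×1 convolution then maps to class scores.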
Implementing the Double Convolution Block
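Each resolution level of U-Net applies two consecutive 3×3 convolutions, each followed by batch normalization and a ReLU activation. The original paper did not use batch normalization; it is a common addition in modern implementations. This double convolution block is the architecture’s fundamental building block.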
```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two consecutive convolution layers with BatchNorm and ReLU"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
```
The padding=1 maintains spatial dimensions through convolutions. Setting bias=False is standard practice when using batch normalization, which includes its own bias term. The inplace=True argument for ReLU saves memory by modifying tensors in place.
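The effect of padding=1 follows from the standard convolution output-size formula (as given in the PyTorch Conv2d documentation, with dilation 1). The sketch below uses a hypothetical helper, `conv_out_size`, to compare the padded double convolution used here with the unpadded one from the original paper, for the paper's 572-pixel input tile.

```python
def conv_out_size(size, kernel_size=3, padding=1, stride=1):
    """Length of one spatial dimension after a convolution (dilation 1):
    floor((size + 2 * padding - kernel_size) / stride) + 1
    """
    return (size + 2 * padding - kernel_size) // stride + 1


# Padded double conv (as in DoubleConv): spatial size is preserved
padded = conv_out_size(conv_out_size(572))
# Unpadded double conv (as in the original paper): each conv trims 2 pixels
unpadded = conv_out_size(conv_out_size(572, padding=0), padding=0)

print(padded, unpadded)  # 572 568
```

Each unpadded 3×3 convolution trims two pixels per dimension, which is why the original paper's output map was smaller than its input.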
Building the Encoder (Contracting Path)
The encoder downsamples feature maps while increasing channel depth. Each downsampling step uses max pooling followed by a double convolution block.
```python
class Down(nn.Module):
    """Downsampling with maxpool followed by double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
```
Max pooling with a kernel size of 2 halves both height and width dimensions. This aggressive downsampling expands the receptive field, allowing the network to capture broader context. The trade-off is spatial resolution, which skip connections will later recover.
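A quick check of the halving behavior, using only nn.MaxPool2d. With the default ceil_mode=False, odd sizes are floored, which is the source of the size mismatches the decoder later has to pad away. The tensor sizes are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2)  # kernel_size=2; stride defaults to kernel_size

even = torch.randn(1, 64, 128, 128)
odd = torch.randn(1, 64, 75, 75)
print(pool(even).shape)  # torch.Size([1, 64, 64, 64])
print(pool(odd).shape)   # torch.Size([1, 64, 37, 37]): odd sizes are floored

# Each output value is the maximum of one non-overlapping 2x2 window
x = torch.arange(16.0).reshape(1, 1, 4, 4)
print(pool(x))  # values 5, 7, 13, 15
```

Channels pass through unchanged; in Down, it is the following DoubleConv that doubles them.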
Building the Decoder (Expansive Path)
The decoder reverses the encoder’s operations—upsampling to increase spatial dimensions while reducing channel depth. The critical addition is concatenating skip connections before applying convolutions.
```python
class Up(nn.Module):
    """Upsampling followed by double conv with skip connection"""
    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Handle size mismatches from odd-sized inputs
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                    diffY // 2, diffY - diffY // 2])
        # Concatenate skip connection along channel dimension
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
```
The bilinear flag offers a choice between learnable transposed convolutions and simple bilinear upsampling. Transposed convolutions add learnable parameters but can create checkerboard artifacts. Bilinear upsampling is parameter-free and often produces smoother results.
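To see the trade-off in numbers, the sketch below applies both upsampling options to a small tensor. The channel count of 8 is an arbitrary illustrative choice; the layer arguments mirror those in Up.

```python
import torch
import torch.nn as nn

in_channels = 8
x = torch.randn(1, in_channels, 5, 5)

# Learnable: halves channels and doubles H, W (as in Up with bilinear=False)
up_transposed = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
# Parameter-free: keeps channels and doubles H, W (as in Up with bilinear=True)
up_bilinear = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

y_t = up_transposed(x)
y_b = up_bilinear(x)

# Weight (8 x 4 x 2 x 2) plus bias (4) for the transposed conv; none for Upsample
n_params_t = sum(p.numel() for p in up_transposed.parameters())
n_params_b = sum(p.numel() for p in up_bilinear.parameters())

print(y_t.shape, n_params_t)  # torch.Size([1, 4, 10, 10]) 132
print(y_b.shape, n_params_b)  # torch.Size([1, 8, 10, 10]) 0
```

Because nn.Upsample keeps the channel count, the bilinear variant cannot halve channels while upsampling; that is why UNet divides the bottleneck and decoder widths by `factor` when bilinear=True.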
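The padding logic handles inputs whose height or width isn’t divisible by 16 (the total downsampling factor of four 2× poolings). Max pooling floors odd sizes, so an upsampled tensor can come out one pixel smaller than its skip connection, and padding restores the match. The torch.cat operation then concatenates along dimension 1 (channels), combining the upsampled and skip features into in_channels channels, which the double convolution reduces to out_channels.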
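The padding and concatenation steps in Up.forward can be checked in isolation. This sketch assumes a 15×15 skip tensor (pooled to 7×7, then upsampled to 14×14) and 32 channels on each side; the sizes are illustrative.

```python
import torch
import torch.nn.functional as F

# Skip connection from the encoder with an odd spatial size
x2 = torch.randn(1, 32, 15, 15)
# Decoder features after pooling (15 -> 7) and x2 upsampling (7 -> 14)
x1 = torch.randn(1, 32, 14, 14)

diffY = x2.size(2) - x1.size(2)  # height difference: 1
diffX = x2.size(3) - x1.size(3)  # width difference: 1

# F.pad lists (left, right, top, bottom) for the last two dimensions
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

# Concatenate along channels: 32 + 32 = 64
x = torch.cat([x2, x1], dim=1)
print(x.shape)  # torch.Size([1, 64, 15, 15])
```

With a difference of 1, the integer split puts the single padded row at the bottom and the single padded column on the right, filled with zeros.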
Assembling the Complete U-Net Model
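The complete U-Net combines encoder, bottleneck, and decoder into a cohesive architecture. The implementation follows the original paper’s layout (four downsampling and four upsampling steps, widening from 64 to 1024 channels) with configurable input channels and output classes.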
```python
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Output layer
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder with skip connections stored
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder with skip connections
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
```
The encoder path stores intermediate feature maps (x1 through x4) for skip connections. The decoder receives both the upsampled features and the corresponding encoder features. The final 1×1 convolution maps to the desired number of output classes without changing spatial dimensions.
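A 1×1 convolution is a per-pixel linear map over channels. The sketch below, with an assumed 3 classes and random stand-in features in place of a real decoder output, checks that equivalence against an explicit torch.einsum.

```python
import torch
import torch.nn as nn

n_classes = 3
features = torch.randn(2, 64, 16, 16)  # stand-in for the decoder output
outc = nn.Conv2d(64, n_classes, kernel_size=1)

logits = outc(features)  # shape (2, 3, 16, 16): H and W unchanged

# Same computation as a linear map applied independently at every pixel
weight = outc.weight.squeeze(-1).squeeze(-1)  # (n_classes, 64)
manual = torch.einsum('bchw,kc->bkhw', features, weight) + outc.bias.view(1, -1, 1, 1)

print(torch.allclose(logits, manual, atol=1e-5))  # True
```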
Training Example and Practical Usage
Here’s a training setup for U-Net binary segmentation, with the dataset left as a placeholder:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Initialize model (UNet as defined above)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_channels=3, n_classes=1, bilinear=True).to(device)

# Loss and optimizer
# BCEWithLogitsLoss expects float masks shaped like the logits: (N, 1, H, W)
criterion = nn.BCEWithLogitsLoss()  # Binary segmentation
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0.0
    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)

# Example usage
# Assuming you have a dataset that returns (image, mask) pairs
# train_dataset = YourSegmentationDataset(...)
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# for epoch in range(num_epochs):
#     loss = train_epoch(model, train_loader, criterion, optimizer, device)
#     print(f'Epoch {epoch+1}, Loss: {loss:.4f}')
```
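To smoke-test train_epoch without real data, a stand-in dataset helps. RandomSegmentationDataset below is hypothetical (random images and random masks, not a real dataset); it mainly illustrates the shapes and dtypes BCEWithLogitsLoss expects: one-channel float masks matching the logits.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomSegmentationDataset(Dataset):
    """Hypothetical stand-in for YourSegmentationDataset: random RGB images
    paired with random binary masks."""

    def __init__(self, n_samples=16, size=64):
        self.n_samples = n_samples
        self.size = size  # 64 is divisible by 16, so no decoder padding is needed

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        image = torch.rand(3, self.size, self.size)  # (C, H, W), float32
        # Mask: one channel of 0.0/1.0 floats, same H and W as the image
        mask = (torch.rand(1, self.size, self.size) > 0.5).float()
        return image, mask


train_loader = DataLoader(RandomSegmentationDataset(), batch_size=8, shuffle=True)
images, masks = next(iter(train_loader))
print(images.shape, masks.shape, masks.dtype)
# torch.Size([8, 3, 64, 64]) torch.Size([8, 1, 64, 64]) torch.float32
```

With the model, criterion, and optimizer defined above, passing this train_loader to train_epoch should run end to end, though the loss carries no meaning on random data.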
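For multi-class segmentation, swap BCEWithLogitsLoss for CrossEntropyLoss, set n_classes to your number of classes, and supply masks as integer class indices of shape (N, H, W) with dtype torch.long. The model outputs raw logits: at inference, apply torch.sigmoid() for binary predictions, or torch.softmax(logits, dim=1) followed by an argmax over dimension 1 for multi-class predictions.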
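The sketch below shows the multi-class loss and both inference paths on random tensors; the 4 classes and the 0.5 binary threshold are assumed values for illustration.

```python
import torch
import torch.nn as nn

n_classes = 4
logits = torch.randn(2, n_classes, 32, 32)          # model output (N, C, H, W)
targets = torch.randint(0, n_classes, (2, 32, 32))  # class indices (N, H, W), dtype long

loss = nn.CrossEntropyLoss()(logits, targets)       # scalar

# Multi-class inference: softmax over the channel dimension, then argmax
probs = torch.softmax(logits, dim=1)
pred_classes = probs.argmax(dim=1)                  # (N, H, W)

# Binary inference (n_classes=1): sigmoid, then threshold at 0.5 (assumed)
binary_logits = torch.randn(2, 1, 32, 32)
pred_mask = (torch.sigmoid(binary_logits) > 0.5).float()
```

Since softmax is monotonic, taking the argmax directly on the logits yields the same class map; the softmax step matters only when you need the probabilities themselves.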
Use learning rate scheduling and data augmentation for better results. The original paper emphasizes aggressive augmentation—random rotations, elastic deformations, and intensity shifts help the model generalize from limited training data.
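For segmentation, geometric augmentations must be applied identically to the image and its mask. A minimal sketch, assuming (C, H, W) tensors: `paired_augment` is a hypothetical helper covering flips, 90-degree rotations, and a small intensity shift applied to the image only. Elastic deformations, which the original paper relies on, are not shown.

```python
import random

import torch


def paired_augment(image, mask):
    """Apply the same random flips and 90-degree rotation to image and mask,
    then shift image intensity. Both tensors are expected as (C, H, W)."""
    if random.random() < 0.5:  # horizontal flip (width axis)
        image, mask = torch.flip(image, dims=[2]), torch.flip(mask, dims=[2])
    if random.random() < 0.5:  # vertical flip (height axis)
        image, mask = torch.flip(image, dims=[1]), torch.flip(mask, dims=[1])
    k = random.randint(0, 3)  # number of 90-degree rotations, shared by both
    image = torch.rot90(image, k, dims=[1, 2])
    mask = torch.rot90(mask, k, dims=[1, 2])
    # Intensity shift in [-0.05, 0.05] on the image only; mask labels stay intact
    image = image + 0.1 * (random.random() - 0.5)
    return image, mask


img, msk = paired_augment(torch.rand(3, 64, 64), torch.zeros(1, 64, 64))
```

For the scheduling half of the advice, torch.optim.lr_scheduler provides ready-made options such as StepLR and ReduceLROnPlateau.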
U-Net’s modular design makes it easy to modify. Try adjusting the initial feature count, adding dropout for regularization, or experimenting with different normalization techniques. The architecture’s flexibility is why it remains a go-to choice for segmentation tasks nearly a decade after publication.
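As one way to try those modifications, the sketch below builds a double-convolution block with a selectable normalization layer and optional nn.Dropout2d. The `conv_block` helper, its option names, the dropout placement at the end of the block, and the choice of 8 groups for GroupNorm are illustrative assumptions rather than settings from the original paper.

```python
import torch
import torch.nn as nn


def conv_block(in_channels, out_channels, norm="batch", dropout=0.0):
    """Double conv with a selectable normalization layer and optional dropout."""
    def make_norm(channels):
        if norm == "batch":
            return nn.BatchNorm2d(channels)
        if norm == "instance":
            return nn.InstanceNorm2d(channels, affine=True)
        if norm == "group":
            # channels must be divisible by num_groups
            return nn.GroupNorm(num_groups=8, num_channels=channels)
        raise ValueError(f"unknown norm: {norm}")

    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        make_norm(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
        make_norm(out_channels),
        nn.ReLU(inplace=True),
    ]
    if dropout > 0:
        # Dropout2d zeroes whole feature maps rather than single activations
        layers.append(nn.Dropout2d(p=dropout))
    return nn.Sequential(*layers)


block = conv_block(3, 64, norm="group", dropout=0.1)
out = block(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 64, 32, 32])
```

GroupNorm and InstanceNorm do not depend on batch statistics, which can help when memory limits training to very small batches.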