How to Implement Semantic Segmentation in PyTorch
Key Insights
- Semantic segmentation requires pixel-wise classification and specialized architectures like U-Net that preserve spatial information through encoder-decoder structures with skip connections
- Dataset preparation is critical—masks must align perfectly with images, and augmentations must be applied identically to both inputs and targets
- Combining multiple loss functions (Cross-Entropy + Dice Loss) and using pre-trained encoders dramatically improves segmentation performance on limited datasets
Introduction to Semantic Segmentation
Semantic segmentation is the task of classifying every pixel in an image into a predefined category. Unlike image classification, which assigns a single label to an entire image, or object detection, which draws bounding boxes around objects, semantic segmentation produces a dense prediction map where each pixel gets its own class label.
This pixel-level understanding is essential for applications like autonomous driving (identifying roads, pedestrians, vehicles), medical imaging (tumor detection, organ segmentation), and satellite imagery analysis (land use classification). The challenge lies in maintaining spatial resolution while building sufficiently deep networks to capture semantic information.
In this article, we’ll build a complete semantic segmentation pipeline in PyTorch, from data loading through training to inference. We’ll implement U-Net, one of the most effective architectures for this task, and explore practical techniques that actually improve results.
Dataset Preparation and Data Loading
Segmentation datasets consist of image-mask pairs where the mask is a 2D array with integer class labels for each pixel. Your dataset structure should look like this:
dataset/
├── images/
│   ├── img_001.jpg
│   └── img_002.jpg
└── masks/
    ├── img_001.png
    └── img_002.png
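Because the Dataset class below assumes every image has a mask with the same base name (the extensions differ, per the layout above), it is worth verifying the pairing before training. A minimal sketch, with check_pairs as a hypothetical helper:

```python
import os

def check_pairs(image_dir, mask_dir):
    """Return image filenames that have no mask with the same base name."""
    mask_stems = {os.path.splitext(f)[0] for f in os.listdir(mask_dir)}
    return [f for f in sorted(os.listdir(image_dir))
            if os.path.splitext(f)[0] not in mask_stems]
```

An empty return value means every image is paired; anything else should be fixed before it silently breaks `__getitem__`.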
Here’s a robust Dataset class that handles loading and augmentation:
import os

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Masks share the image's base name but are stored as PNG
        mask_name = os.path.splitext(self.images[idx])[0] + ".png"
        mask_path = os.path.join(self.mask_dir, mask_name)
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.int64)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        return image, mask
For augmentation, use the albumentations library—it applies transformations consistently to both images and masks:
train_transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])
train_dataset = SegmentationDataset('data/train/images', 'data/train/masks', train_transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
Building the U-Net Architecture
U-Net’s power comes from its symmetric encoder-decoder structure with skip connections that preserve fine-grained spatial information. The encoder downsamples to capture context, while the decoder upsamples to produce the segmentation map.
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=21):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)
        # Decoder
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        self.out = nn.Conv2d(64, num_classes, 1)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        # Bottleneck
        b = self.bottleneck(self.pool(e4))
        # Decoder with skip connections
        d4 = self.upconv4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)
        d3 = self.upconv3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        d2 = self.upconv2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        d1 = self.upconv1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        return self.out(d1)
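One detail worth verifying is that ConvTranspose2d with kernel size 2 and stride 2 exactly doubles the spatial resolution, which is what makes each upsampled tensor concatenable with its encoder skip connection. For a 256x256 input the bottleneck is 16x16, so the first decoder step looks like this:

```python
import torch
import torch.nn as nn

up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
bottleneck = torch.randn(1, 1024, 16, 16)  # bottleneck features for a 256x256 input
d4 = up(bottleneck)
# (1, 512, 32, 32): same spatial size as e4, so torch.cat([d4, e4], dim=1) works
print(d4.shape)
```

This doubling also explains why input sizes should be divisible by 16 (four pooling steps): otherwise the upsampled and skip tensors will disagree by a pixel and the concatenation fails.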
Training Loop and Loss Functions
Cross-Entropy Loss works but struggles with class imbalance. Dice Loss directly optimizes the overlap between prediction and ground truth:
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        pred = torch.softmax(pred, dim=1)
        target_one_hot = F.one_hot(target, num_classes=pred.shape[1]).permute(0, 3, 1, 2).float()
        intersection = (pred * target_one_hot).sum(dim=(2, 3))
        union = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()
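As a quick sanity check on the formula, logits that put essentially all their probability mass on the correct class at every pixel should drive the Dice loss to zero. Re-deriving the same computation standalone:

```python
import torch
import torch.nn.functional as F

target = torch.randint(0, 3, (2, 8, 8))  # (N, H, W) integer class labels
# Logits that strongly favour the correct class at every pixel
logits = F.one_hot(target, 3).permute(0, 3, 1, 2).float() * 100

probs = torch.softmax(logits, dim=1)
one_hot = F.one_hot(target, 3).permute(0, 3, 1, 2).float()
intersection = (probs * one_hot).sum(dim=(2, 3))
union = probs.sum(dim=(2, 3)) + one_hot.sum(dim=(2, 3))
loss = 1 - ((2 * intersection + 1e-6) / (union + 1e-6)).mean()
assert loss.item() < 1e-3  # near-perfect prediction -> Dice loss ~ 0
```

The smooth term also keeps the ratio at 1 (rather than 0/0) for classes absent from a sample, which is why it matters beyond numerical safety.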
For evaluation, mean Intersection over Union (IoU) is the standard segmentation metric:

def calculate_iou(pred, target, num_classes):
    ious = []
    pred = pred.view(-1)
    target = target.view(-1)
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().float()
        union = (pred_inds | target_inds).sum().float()
        if union == 0:
            ious.append(float('nan'))  # class absent from both: skip it
        else:
            ious.append((intersection / union).item())
    return np.nanmean(ious)
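Walking through the IoU computation on a tiny hand-picked example (values chosen purely for illustration) makes the per-class averaging concrete:

```python
import torch

pred   = torch.tensor([0, 1, 1, 1])  # flattened predicted labels
target = torch.tensor([0, 0, 1, 1])  # flattened ground-truth labels

ious = []
for cls in range(2):
    p, t = pred == cls, target == cls
    inter = (p & t).sum().item()
    union = (p | t).sum().item()
    ious.append(inter / union)

print(ious)  # class 0: 1/2 = 0.5, class 1: 2/3 ~ 0.667
```

Averaging per-class IoUs (rather than pooling all pixels) keeps rare classes from being drowned out by the background, which is the main reason mean IoU is preferred over pixel accuracy.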
Here’s the complete training loop:
def train_model(model, train_loader, val_loader, num_epochs=50):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    ce_loss = nn.CrossEntropyLoss()
    dice_loss = DiceLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = ce_loss(outputs, masks) + dice_loss(outputs, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0
        val_iou = 0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = ce_loss(outputs, masks) + dice_loss(outputs, masks)
                val_loss += loss.item()
                preds = torch.argmax(outputs, dim=1)
                val_iou += calculate_iou(preds, masks, num_classes=21)

        val_loss /= len(val_loader)
        val_iou /= len(val_loader)
        scheduler.step(val_loss)
        print(f'Epoch {epoch+1}: Train Loss={train_loss/len(train_loader):.4f}, '
              f'Val Loss={val_loss:.4f}, Val IoU={val_iou:.4f}')
Inference and Visualization
For inference, handle single images and visualize results with color-coded masks:
import matplotlib.pyplot as plt
from matplotlib import colormaps

def predict_image(model, image_path, transform, device):
    model.eval()
    image = np.array(Image.open(image_path).convert("RGB"))
    original = image.copy()
    augmented = transform(image=image)
    image_tensor = augmented['image'].unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
        pred_mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()
    return original, pred_mask

def visualize_prediction(image, mask, num_classes=21):
    # plt.cm.get_cmap was removed in recent matplotlib; use the colormaps registry
    cmap = colormaps['tab20'].resampled(num_classes)
    colored_mask = cmap(mask)[:, :, :3]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[1].imshow(mask)
    axes[1].set_title('Prediction')
    axes[2].imshow(image)
    axes[2].imshow(colored_mask, alpha=0.5)
    axes[2].set_title('Overlay')
    plt.show()
Performance Optimization and Best Practices
Use pre-trained encoders for better performance with less data:
import segmentation_models_pytorch as smp
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=21
)
Handle class imbalance with weighted loss:
def calculate_class_weights(dataloader, num_classes):
    class_counts = torch.zeros(num_classes)
    for _, masks in dataloader:
        for c in range(num_classes):
            class_counts[c] += (masks == c).sum()
    weights = 1.0 / (class_counts + 1e-6)
    weights = weights / weights.sum() * num_classes
    return weights

weights = calculate_class_weights(train_loader, 21).to(device)
ce_loss = nn.CrossEntropyLoss(weight=weights)
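On toy counts you can see what the normalization does: the rarest class receives the largest weight, and the weights sum to num_classes so the overall loss magnitude stays comparable to the unweighted case (counts below are made up for illustration):

```python
import torch

class_counts = torch.tensor([900.0, 90.0, 10.0])  # toy pixel counts per class
weights = 1.0 / (class_counts + 1e-6)
weights = weights / weights.sum() * 3  # normalize so weights sum to num_classes

# The rarest class gets the largest weight; the sum equals num_classes
assert weights[2] > weights[1] > weights[0]
assert abs(weights.sum().item() - 3.0) < 1e-4
```

For extremely skewed datasets, pure inverse-frequency weights can become very large; capping them or using inverse square-root frequency is a common softening.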
Enable mixed precision training for faster training:
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for images, masks in train_loader:
    images, masks = images.to(device), masks.to(device)
    optimizer.zero_grad()
    with autocast():
        outputs = model(images)
        loss = ce_loss(outputs, masks) + dice_loss(outputs, masks)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
Semantic segmentation is computationally intensive, but these optimizations make it practical. Start with a pre-trained encoder, use combined losses, and always validate your augmentation pipeline by visualizing a few samples. The key to good segmentation is quality data and appropriate architecture choices—U-Net remains the gold standard for most applications.