How to Implement Image Classification in PyTorch
Key Insights
- PyTorch’s dynamic computational graph and intuitive API make it ideal for rapid prototyping and experimentation with image classification models
- Proper data preprocessing with transforms and DataLoaders is critical—normalization values must match your model’s training distribution, and batch size directly impacts training stability and speed
- Start with a simple CNN architecture before jumping to transfer learning; understanding the fundamentals of convolution layers, pooling, and forward passes will make debugging production issues significantly easier
Introduction & Setup
Image classification is the task of assigning a label to an image from a predefined set of categories. PyTorch has become the framework of choice for this task due to its pythonic design, excellent debugging capabilities, and seamless GPU acceleration. Unlike static graph frameworks, PyTorch builds computational graphs on-the-fly, making it natural to write and debug.
First, install the necessary dependencies. PyTorch installation varies by system and CUDA version, so check the official website for your specific command.
# For CPU-only (development)
pip install torch torchvision torchaudio
# For CUDA 11.8 (check pytorch.org for your version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Additional utilities
pip install matplotlib numpy pillow
Here are the essential imports you’ll need:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
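The dynamic-graph claim above is easy to demonstrate: because operations execute eagerly, you can print shapes, branch on tensor values, and drop into a debugger mid-computation. A minimal illustration (the shapes and names here are arbitrary):

```python
import torch

# PyTorch executes operations eagerly, so ordinary Python tools work
# mid-computation: print shapes, set breakpoints, branch on tensor values.
x = torch.randn(4, 3, 32, 32)   # a fake batch of 4 RGB 32x32 images
y = torch.nn.functional.relu(x)

print(y.shape)                  # torch.Size([4, 3, 32, 32])
if y.min().item() >= 0:         # data-dependent control flow is fine
    print("ReLU clamped all negatives to zero")
```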
Loading and Preparing the Dataset
Data preparation is where most beginners stumble. The transforms you apply must match what your model expects, and the DataLoader configuration significantly impacts training performance.
CIFAR-10 is an excellent starting dataset—60,000 32x32 color images across 10 classes. Here’s how to load it with proper preprocessing:
# Define transforms for training and testing
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Data augmentation
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

# Load datasets
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test
)

# Create DataLoaders
trainloader = DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2
)
testloader = DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2
)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
The normalization values aren’t arbitrary—they’re the per-channel mean and standard deviation of the CIFAR-10 training set. Data augmentation (flipping, cropping) is applied only to the training set: it combats overfitting by showing the model varied views of each image, while the test set stays deterministic so evaluation results are comparable across runs.
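If you target a different dataset, per-channel statistics like these can be recomputed with plain tensor reductions. A minimal sketch, using a random tensor as a stand-in for the stacked training images (for real data you’d concatenate batches loaded with only ToTensor() applied):

```python
import torch

# Random tensor standing in for a stack of unnormalized training images,
# shaped (N, C, H, W) with values in [0, 1]. For real data, build this by
# concatenating batches from a DataLoader whose transform is only ToTensor().
images = torch.rand(500, 3, 32, 32)

# Reduce over the batch and spatial dimensions, keeping the channel axis
mean = images.mean(dim=[0, 2, 3])
std = images.std(dim=[0, 2, 3])
print(mean, std)  # for uniform noise: mean near 0.5, std near 0.29
```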
Building the CNN Model
A convolutional neural network extracts hierarchical features from images. Early layers detect edges, middle layers detect shapes, and deeper layers recognize complex patterns.
Here’s a practical CNN architecture:
class ImageClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super(ImageClassifier, self).__init__()
        # First convolutional block
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        # Second convolutional block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Third convolutional block
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Fully connected layers
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Block 1
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        # Block 2
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)
        # Block 3
        x = self.relu(self.bn5(self.conv5(x)))
        x = self.relu(self.bn6(self.conv6(x)))
        x = self.pool3(x)
        # Flatten and classify
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Instantiate model
model = ImageClassifier(num_classes=10).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
BatchNorm layers stabilize training, dropout prevents overfitting, and the architecture progressively reduces spatial dimensions while increasing channel depth.
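The arithmetic behind the 256 * 4 * 4 input to fc1 is worth verifying: each MaxPool2d(2, 2) halves the spatial size, so 32x32 becomes 4x4 after three blocks. A quick shape trace with freshly constructed layers (BatchNorm and ReLU are omitted since they don’t change shape):

```python
import torch
import torch.nn as nn

# Three conv blocks, each halving spatial size with MaxPool2d(2, 2),
# take a 32x32 input down to 4x4 — matching the 256 * 4 * 4 input
# size of the first fully connected layer.
x = torch.randn(1, 3, 32, 32)  # one fake CIFAR-10 image
pool = nn.MaxPool2d(2, 2)

x = pool(nn.Conv2d(3, 64, 3, padding=1)(x))     # -> (1, 64, 16, 16)
x = pool(nn.Conv2d(64, 128, 3, padding=1)(x))   # -> (1, 128, 8, 8)
x = pool(nn.Conv2d(128, 256, 3, padding=1)(x))  # -> (1, 256, 4, 4)
print(x.flatten(1).shape)  # torch.Size([1, 4096]), i.e. 256 * 4 * 4
```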
Training the Model
The training loop is where your model learns. This implementation includes proper loss tracking and validation:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

num_epochs = 50
best_accuracy = 0.0

for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(trainloader):
        images, labels = images.to(device), labels.to(device)
        # Zero gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(trainloader)}], '
                  f'Loss: {running_loss/100:.4f}, Acc: {100*correct/total:.2f}%')
            running_loss = 0.0

    # Validation phase
    model.eval()
    val_correct = 0
    val_total = 0
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_accuracy = 100 * val_correct / val_total
    avg_val_loss = val_loss / len(testloader)
    print(f'Epoch [{epoch+1}/{num_epochs}] Validation Accuracy: {val_accuracy:.2f}%')

    # Learning rate scheduling
    scheduler.step(avg_val_loss)

    # Save best model
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')
Evaluation and Testing
After training, evaluate your model systematically:
def evaluate_model(model, testloader, device):
    model.eval()
    correct = 0
    total = 0
    class_correct = [0] * 10
    class_total = [0] * 10
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Per-class accuracy
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    print(f'Overall Accuracy: {100 * correct / total:.2f}%')
    print('\nPer-class accuracy:')
    for i in range(10):
        print(f'{classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')

# Visualize predictions
def show_predictions(model, testloader, device, num_images=8):
    model.eval()
    images, labels = next(iter(testloader))
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        outputs = model(images)
    _, predicted = torch.max(outputs, 1)
    images = images.cpu()
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for idx, ax in enumerate(axes.flat):
        # Undo normalization so the image displays with natural colors
        img = images[idx].permute(1, 2, 0).numpy()
        img = img * np.array([0.2023, 0.1994, 0.2010]) + np.array([0.4914, 0.4822, 0.4465])
        img = np.clip(img, 0, 1)
        ax.imshow(img)
        ax.set_title(f'Pred: {classes[predicted[idx]]}\nTrue: {classes[labels[idx]]}')
        ax.axis('off')
    plt.tight_layout()
    plt.show()

evaluate_model(model, testloader, device)
show_predictions(model, testloader, device)
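Per-class accuracy hides which classes get confused with which; a confusion matrix makes that visible. A sketch of the bookkeeping, using random tensors as stand-ins for a real pass over testloader:

```python
import torch

# Confusion matrix: rows are true classes, columns are predicted classes.
# Random labels/predictions stand in for a real pass over the test set.
num_classes = 10
labels = torch.randint(0, num_classes, (1000,))
predicted = torch.randint(0, num_classes, (1000,))

confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)
for t, p in zip(labels, predicted):
    confusion[t, p] += 1

# The diagonal holds correct predictions; row sums are per-class totals
per_class_acc = confusion.diag().float() / confusion.sum(dim=1).float()
print(per_class_acc)
```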
Saving and Loading Models
Always save the state dictionary, not the entire model:
# Save model
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'accuracy': best_accuracy,
}, 'checkpoint.pth')

# Load model (map_location lets a GPU-trained checkpoint load on CPU)
checkpoint = torch.load('checkpoint.pth', map_location=device)
model = ImageClassifier(num_classes=10).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# For inference only
torch.save(model.state_dict(), 'model_weights.pth')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
Next Steps and Optimization
Once you’ve mastered basic CNNs, leverage transfer learning with pre-trained models:
import torchvision.models as models

# Load pre-trained ResNet (on torchvision < 0.13, use pretrained=True instead)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze early layers
for param in model.parameters():
    param.requires_grad = False

# Replace final layer for your classes
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

# Only train the final layer initially
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
Transfer learning can reach 90%+ accuracy on CIFAR-10 with relatively little training; resizing the 32x32 images toward the 224x224 resolution the pre-trained weights expect (e.g., transforms.Resize(224)) helps the features transfer. Other optimizations include mixed precision training with torch.cuda.amp, gradient accumulation for larger effective batch sizes, and experimenting with optimizers like AdamW or SGD with momentum.
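A hedged sketch of what the mixed-precision loop could look like (a tiny nn.Linear stands in for the CNN above; autocast and GradScaler are disabled automatically when no GPU is present, so the snippet also runs on CPU):

```python
import torch
import torch.nn as nn

# Mixed precision sketch with torch.cuda.amp: autocast runs the forward
# pass in half precision where safe; GradScaler scales the loss so small
# fp16 gradients don't underflow during backward.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(10, 2).to(device)  # stand-in for the real CNN
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

images = torch.randn(8, 10, device=device)
labels = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device.type, enabled=(device.type == "cuda")):
    loss = criterion(model(images), labels)
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales gradients, then steps
scaler.update()
print(f"loss: {loss.item():.4f}")
```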
The foundation you’ve built here scales to any image classification problem—just swap the dataset and adjust the final layer’s output size.