How to Implement Image Classification in PyTorch

Key Insights

  • PyTorch’s dynamic computational graph and intuitive API make it ideal for rapid prototyping and experimentation with image classification models
  • Proper data preprocessing with transforms and DataLoaders is critical—normalization values must match your model’s training distribution, and batch size directly impacts training stability and speed
  • Start with a simple CNN architecture before jumping to transfer learning; understanding the fundamentals of convolution layers, pooling, and forward passes will make debugging production issues significantly easier

Introduction & Setup

Image classification is the task of assigning a label to an image from a predefined set of categories. PyTorch has become the framework of choice for this task due to its pythonic design, excellent debugging capabilities, and seamless GPU acceleration. Unlike static graph frameworks, PyTorch builds computational graphs on-the-fly, making it natural to write and debug.

First, install the necessary dependencies. PyTorch installation varies by system and CUDA version, so check the official website for your specific command.

# For CPU-only (development)
pip install torch torchvision torchaudio

# For CUDA 11.8 (check pytorch.org for your version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Additional utilities
pip install matplotlib numpy pillow

Here are the essential imports you’ll need:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Loading and Preparing the Dataset

Data preparation is where most beginners stumble. The transforms you apply must match what your model expects, and the DataLoader configuration significantly impacts training performance.

CIFAR-10 is an excellent starting dataset: 60,000 32x32 color images (50,000 train, 10,000 test) across 10 classes. Here’s how to load it with proper preprocessing:

# Define transforms for training and testing
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Data augmentation
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), 
                         (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), 
                         (0.2023, 0.1994, 0.2010))
])

# Load datasets
trainset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=True,
    download=True, 
    transform=transform_train
)

testset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False,
    download=True, 
    transform=transform_test
)

# Create DataLoaders
trainloader = DataLoader(
    trainset, 
    batch_size=128,
    shuffle=True, 
    num_workers=2
)

testloader = DataLoader(
    testset, 
    batch_size=100,
    shuffle=False, 
    num_workers=2
)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

The normalization values aren’t arbitrary: they’re the per-channel mean and standard deviation of the CIFAR-10 training images. Data augmentation (flipping, cropping) is applied only to the training set; it reduces overfitting during training, while the test set stays deterministic so evaluation is repeatable.
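If you adapt this pipeline to a new dataset, you’ll need that dataset’s own statistics. Here is a minimal sketch of the per-channel mean/std computation (the channel_stats helper is ours; it’s demonstrated on synthetic data here, but in practice you’d run it over your unnormalized training set):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def channel_stats(loader):
    """Per-channel mean and std over every pixel in a dataset of CHW tensors."""
    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    num_pixels = 0
    for images, _ in loader:
        channel_sum += images.sum(dim=[0, 2, 3])
        channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
        num_pixels += images.size(0) * images.size(2) * images.size(3)
    mean = channel_sum / num_pixels
    std = (channel_sq_sum / num_pixels - mean ** 2).sqrt()  # std = sqrt(E[x^2] - E[x]^2)
    return mean, std

# Demo on synthetic data shaped like CIFAR-10 (3x32x32, values in [0, 1])
fake = TensorDataset(torch.rand(200, 3, 32, 32), torch.zeros(200))
mean, std = channel_stats(DataLoader(fake, batch_size=64))
print(mean, std)  # for uniform [0, 1] data: roughly 0.5 (mean) and 0.29 (std) per channel
```

Note that published CIFAR-10 statistics vary slightly depending on how the std is aggregated; what matters is that training and inference use the same values.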

Building the CNN Model

A convolutional neural network extracts hierarchical features from images. Early layers detect edges, middle layers detect shapes, and deeper layers recognize complex patterns.

Here’s a practical CNN architecture:

class ImageClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super(ImageClassifier, self).__init__()
        
        # First convolutional block
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        # Second convolutional block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2, 2)
        
        # Third convolutional block
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        # Fully connected layers
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        # Block 1
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        
        # Block 2
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)
        
        # Block 3
        x = self.relu(self.bn5(self.conv5(x)))
        x = self.relu(self.bn6(self.conv6(x)))
        x = self.pool3(x)
        
        # Flatten and classify
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

# Instantiate model
model = ImageClassifier(num_classes=10).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

BatchNorm layers stabilize training, dropout prevents overfitting, and the architecture progressively reduces spatial dimensions while increasing channel depth.
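A quick sanity check for the 256 * 4 * 4 figure in fc1: each MaxPool2d(2, 2) halves the spatial resolution (32 → 16 → 8 → 4), while the 3x3 convolutions with padding=1 preserve it. You can verify this by tracing a dummy batch through a simplified stack with the same pooling structure:

```python
import torch
import torch.nn as nn

# Same channel progression and pooling as the model above (convs per block omitted
# for brevity; 3x3 convs with padding=1 don't change spatial size anyway)
conv_stack = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.MaxPool2d(2, 2),    # 32 -> 16
    nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.MaxPool2d(2, 2),  # 16 -> 8
    nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.MaxPool2d(2, 2), # 8 -> 4
)
dummy = torch.zeros(1, 3, 32, 32)  # one CIFAR-10-sized image
out = conv_stack(dummy)
print(out.shape)  # torch.Size([1, 256, 4, 4]) -> flattens to 256 * 4 * 4 = 4096
```

If you change the input resolution or the number of pooling stages, re-check this first; a shape mismatch at fc1 is the most common error when adapting the architecture.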

Training the Model

The training loop is where your model learns. This implementation includes proper loss tracking and validation:

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

num_epochs = 50
best_accuracy = 0.0

for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for i, (images, labels) in enumerate(trainloader):
        images, labels = images.to(device), labels.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(trainloader)}], '
                  f'Loss: {running_loss/100:.4f}, Acc: {100*correct/total:.2f}%')
            running_loss = 0.0
    
    # Validation phase
    model.eval()
    val_correct = 0
    val_total = 0
    val_loss = 0.0
    
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    
    val_accuracy = 100 * val_correct / val_total
    avg_val_loss = val_loss / len(testloader)
    
    print(f'Epoch [{epoch+1}/{num_epochs}] Validation Accuracy: {val_accuracy:.2f}%')
    
    # Learning rate scheduling
    scheduler.step(avg_val_loss)
    
    # Save best model
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')

Evaluation and Testing

After training, evaluate your model systematically:

def evaluate_model(model, testloader, device):
    model.eval()
    correct = 0
    total = 0
    class_correct = [0] * 10
    class_total = [0] * 10
    
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Per-class accuracy
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    
    print(f'Overall Accuracy: {100 * correct / total:.2f}%')
    print('\nPer-class accuracy:')
    for i in range(10):
        print(f'{classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')

# Visualize predictions
def show_predictions(model, testloader, device, num_images=8):
    model.eval()
    images, labels = next(iter(testloader))
    images, labels = images.to(device), labels.to(device)
    
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)
    
    images = images.cpu()
    fig, axes = plt.subplots(2, num_images // 2, figsize=(12, 6))
    for idx, ax in enumerate(axes.flat):
        img = images[idx].permute(1, 2, 0).numpy()
        img = img * np.array([0.2023, 0.1994, 0.2010]) + np.array([0.4914, 0.4822, 0.4465])
        img = np.clip(img, 0, 1)
        ax.imshow(img)
        ax.set_title(f'Pred: {classes[predicted[idx]]}\nTrue: {classes[labels[idx]]}')
        ax.axis('off')
    plt.tight_layout()
    plt.show()

evaluate_model(model, testloader, device)
show_predictions(model, testloader, device)

Saving and Loading Models

Always save the state dictionary, not the entire model:

# Save model
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'accuracy': best_accuracy,
}, 'checkpoint.pth')

# Load model
checkpoint = torch.load('checkpoint.pth', map_location=device)
model = ImageClassifier(num_classes=10).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# For inference only
torch.save(model.state_dict(), 'model_weights.pth')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))

Next Steps and Optimization

Once you’ve mastered basic CNNs, leverage transfer learning with pre-trained models:

import torchvision.models as models

# Load pre-trained ResNet (the weights argument replaces the deprecated pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze early layers
for param in model.parameters():
    param.requires_grad = False

# Replace final layer for your classes
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

# Only train the final layer initially
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

Transfer learning typically achieves 90%+ accuracy on CIFAR-10 with minimal training. Other optimizations include mixed precision training with torch.cuda.amp, gradient accumulation for larger effective batch sizes, and experimenting with optimizers like AdamW or SGD with momentum.
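As an example of the first optimization, a mixed-precision training step wraps the forward pass in an autocast context and scales the loss before backprop (a sketch; the helper name is ours, and it falls back to full precision when no GPU is available):

```python
import torch
import torch.nn as nn
import torch.optim as optim

def amp_train_step(model, images, labels, optimizer, criterion, scaler, device):
    """One mixed-precision training step; AMP is enabled only on CUDA devices."""
    use_amp = device.type == 'cuda'
    optimizer.zero_grad()
    # Run the forward pass and loss in reduced precision where numerically safe
    with torch.autocast(device_type=device.type, enabled=use_amp):
        outputs = model(images)
        loss = criterion(outputs, labels)
    # Scale the loss so small float16 gradients don't underflow to zero
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

# Usage: create the scaler once, outside the epoch loop
# scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')
# loss = amp_train_step(model, images, labels, optimizer, criterion, scaler, device)
```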

The foundation you’ve built here scales to any image classification problem—just swap the dataset and adjust the final layer’s output size.
