How to Implement Batch Normalization in PyTorch

Key Insights

  • Batch normalization normalizes layer inputs using batch statistics during training but switches to running statistics during inference—forgetting to call model.eval() is one of the most common bugs in production systems.
  • Place batch normalization after convolutional or linear layers but before activation functions for best results, though the after-activation placement can work in specific architectures.
  • Batch size matters significantly: batch normalization degrades with small batches (< 8 samples) because batch statistics become unreliable, making layer normalization or group normalization better alternatives for small-batch scenarios.

Introduction to Batch Normalization

Batch normalization revolutionized deep learning training when introduced in 2015. It addresses internal covariate shift—the phenomenon where the distribution of layer inputs changes during training as previous layers’ parameters update. This instability forces you to use small learning rates and careful initialization.

Batch normalization normalizes each feature to have zero mean and unit variance across a mini-batch, then applies learnable scale (γ) and shift (β) parameters:

y = γ * (x - μ_B) / √(σ²_B + ε) + β

Where μ_B is the batch mean, σ²_B is the batch variance, and ε is a small constant for numerical stability.
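To make the formula concrete, here's a minimal sketch that normalizes a small batch by hand and checks it against nn.BatchNorm1d (the tensor shapes and seed are arbitrary, chosen just for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3)  # batch of 4 samples, 3 features

# Manual computation: y = gamma * (x - mu_B) / sqrt(var_B + eps) + beta
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)  # biased variance, as in the formula
eps = 1e-5
y_manual = (x - mu) / torch.sqrt(var + eps)  # gamma=1, beta=0 at initialization

# Built-in layer in training mode uses the same batch statistics
bn = nn.BatchNorm1d(3)
bn.train()
y_bn = bn(x)

print(torch.allclose(y_manual, y_bn, atol=1e-5))  # True
```

At initialization γ = 1 and β = 0, so the built-in layer reduces to pure normalization and the two results match.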

The benefits are tangible: you can use learning rates 10x higher, training converges faster, and networks become less sensitive to initialization. It also provides a regularization effect, sometimes eliminating the need for dropout.

Here’s a simple comparison showing the impact:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Network without batch norm
class NetWithoutBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
    
    def forward(self, x):
        return self.layers(x)

# Network with batch norm
class NetWithBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
    
    def forward(self, x):
        return self.layers(x)

# Training with higher learning rate becomes stable with BN
model_with_bn = NetWithBN()
optimizer = torch.optim.SGD(model_with_bn.parameters(), lr=0.1)  # 10x higher LR
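To see the effect end to end, here is a minimal self-contained training sketch on synthetic data (the random inputs and labels are illustrative stand-ins, not a benchmark):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(256, 784)          # stand-in for flattened 28x28 images
y = torch.randint(0, 10, (256,))   # random labels, for illustration only

# Small BN network, equivalent in spirit to NetWithBN above
model = nn.Sequential(
    nn.Linear(784, 256), nn.BatchNorm1d(256), nn.ReLU(),
    nn.Linear(256, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # aggressive LR
criterion = nn.CrossEntropyLoss()

model.train()
losses = []
for step in range(20):
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Even at lr=0.1, the loss decreases steadily; the same learning rate on the non-BN variant is far more likely to diverge or oscillate.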

Using PyTorch’s Built-in BatchNorm Layers

PyTorch provides three batch normalization variants:

  • nn.BatchNorm1d: For 2D inputs (batch_size, num_features) or 3D inputs (batch_size, num_features, length). Use for fully connected layers or 1D convolutions.
  • nn.BatchNorm2d: For 4D inputs (batch_size, channels, height, width). Use for 2D convolutions in CNNs.
  • nn.BatchNorm3d: For 5D inputs (batch_size, channels, depth, height, width). Use for 3D convolutions in video or medical imaging.

Key parameters you need to understand:

  • num_features: Number of features/channels to normalize (required)
  • eps: Small constant for numerical stability (default: 1e-5)
  • momentum: Factor for running statistics update (default: 0.1). Note this is the weight given to the *new* batch statistic — running = (1 - momentum) * running + momentum * batch — the opposite of the optimizer-momentum convention.
  • affine: Whether to learn γ and β parameters (default: True)
  • track_running_stats: Whether to track running mean/variance (default: True)
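A quick shape check makes the three variants concrete (the sizes below are arbitrary examples):

```python
import torch
import torch.nn as nn

# BatchNorm1d: (batch, features)
x1 = torch.randn(8, 32)
print(nn.BatchNorm1d(32)(x1).shape)   # torch.Size([8, 32])

# BatchNorm2d: (batch, channels, height, width)
x2 = torch.randn(8, 16, 28, 28)
print(nn.BatchNorm2d(16)(x2).shape)   # torch.Size([8, 16, 28, 28])

# BatchNorm3d: (batch, channels, depth, height, width)
x3 = torch.randn(8, 4, 10, 28, 28)
print(nn.BatchNorm3d(4)(x3).shape)    # torch.Size([8, 4, 10, 28, 28])
```

In every case num_features must match the channel/feature dimension (dim 1), and the output shape equals the input shape.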

Here’s a practical CNN implementation:

class ConvNetWithBN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        
        self.features = nn.Sequential(
            # First conv block
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),  # num_features = number of channels
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            # Second conv block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            # Third conv block
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(256 * 8 * 8, 512),  # assumes 32x32 input: two 2x2 poolings give 8x8 maps
            nn.BatchNorm1d(512),  # Use BatchNorm1d for fully connected
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
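The 256 * 8 * 8 classifier input assumes 32×32 images (e.g., CIFAR-10): two 2×2 poolings halve the spatial size twice, 32 → 16 → 8. A standalone sketch of just the feature extractor confirms the arithmetic:

```python
import torch
import torch.nn as nn

# Same feature stack as ConvNetWithBN above
features = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
)

x = torch.randn(2, 3, 32, 32)  # batch of 2: BN needs more than one value per channel
out = features(x)
print(out.shape)  # torch.Size([2, 256, 8, 8]) -> flattens to 256 * 8 * 8
```

If you feed a different input resolution, recompute the flattened size accordingly (or use nn.AdaptiveAvgPool2d to make the classifier resolution-independent).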

Implementing Batch Normalization from Scratch

Understanding the internals helps debug issues and customize behavior. Here’s a complete implementation:

class CustomBatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        
        # Running statistics (not parameters, won't be trained)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
    
    def forward(self, x):
        # x shape: (batch_size, num_features, height, width)
        
        if self.training:
            # Calculate batch statistics
            batch_mean = x.mean(dim=(0, 2, 3), keepdim=False)
            batch_var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=False)
            
            # Update running statistics in place. Note that PyTorch stores the
            # *unbiased* variance estimate in running_var, even though
            # normalization itself uses the biased estimate.
            with torch.no_grad():
                n = x.numel() / x.size(1)  # elements per channel
                unbiased_var = batch_var * n / (n - 1)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
                self.num_batches_tracked += 1
            
            # Use batch statistics for normalization
            mean = batch_mean
            var = batch_var
        else:
            # Use running statistics during evaluation
            mean = self.running_mean
            var = self.running_var
        
        # Normalize: reshape for broadcasting
        mean = mean.view(1, -1, 1, 1)
        var = var.view(1, -1, 1, 1)
        gamma = self.gamma.view(1, -1, 1, 1)
        beta = self.beta.view(1, -1, 1, 1)
        
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        out = gamma * x_normalized + beta
        
        return out

This implementation shows the critical distinction between training and evaluation modes. During training, we normalize using the current batch’s statistics and update the running averages. During evaluation, we use the accumulated running statistics.
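To sanity-check the core math, you can normalize with per-channel batch statistics by hand and compare against nn.BatchNorm2d in training mode (shapes below are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)

# Per-channel statistics over batch and spatial dimensions
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + 1e-5)  # gamma=1, beta=0 at init

bn = nn.BatchNorm2d(3)
bn.train()
print(torch.allclose(manual, bn(x), atol=1e-5))  # True
```

Training-mode outputs match exactly because both use the biased batch variance for normalization; only the running-average bookkeeping differs between implementations.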

Batch Normalization Placement and Best Practices

The original paper placed batch normalization before activation functions, and this remains the recommended approach:

Conv/Linear → BatchNorm → Activation → Dropout (if used)

This ordering allows batch norm to normalize the pre-activation distribution. However, some modern architectures place it after activation with good results—experiment with your specific use case.

Here’s a ResNet-style residual block showing proper placement:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = self.shortcut(x)
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = torch.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        out += identity
        out = torch.relu(out)
        
        return out

Important considerations:

  • Batch normalization provides regularization, so you may need less or no dropout
  • With batch norm, weight decay becomes more important for regularization
  • Use batch sizes of at least 16-32 for stable statistics; smaller batches make batch norm unreliable

Training vs. Evaluation Mode

This is where most bugs occur. Batch normalization behaves completely differently in training and evaluation:

model = ConvNetWithBN()

# Training mode: uses batch statistics
model.train()
batch = torch.randn(32, 3, 32, 32)
output_train = model(batch)

# Evaluation mode: uses running statistics
model.eval()
single_sample = torch.randn(1, 3, 32, 32)
output_eval = model(single_sample)

# Proper inference setup
def inference(model, data_loader):
    model.eval()  # Critical!
    predictions = []
    
    with torch.no_grad():  # Disable gradient computation
        for batch in data_loader:
            output = model(batch)
            predictions.append(output)
    
    return torch.cat(predictions)

# Testing with different modes
model = ConvNetWithBN()
test_input = torch.randn(16, 3, 32, 32)

model.train()
out_train = model(test_input)  # normalized with this batch's statistics

model.eval()
out_eval = model(test_input)   # normalized with running statistics
print(f"Train vs. eval outputs differ: {not torch.allclose(out_train, out_eval)}")

out_eval2 = model(test_input)
print(f"Eval mode - outputs identical: {torch.allclose(out_eval, out_eval2)}")

In training mode, the output depends on the statistics of whatever batch the input arrives in, so it generally differs from the eval-mode output, which uses the fixed running statistics. In evaluation mode, outputs are deterministic: the same input always produces the same output.

Common Pitfalls and Debugging

Problem 1: Small batch sizes

With batch size 1, batch statistics are meaningless. Use group normalization or layer normalization instead:

# Instead of BatchNorm2d with small batches
nn.GroupNorm(num_groups=32, num_channels=256)  # Divides channels into groups
# or
nn.LayerNorm([256, 32, 32])  # Normalizes over channels, height, width (requires a fixed spatial size)
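A quick demonstration of the failure and the fix, using a fully connected feature vector at batch size 1 (dimensions are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 256)  # batch size 1

# BatchNorm1d cannot compute batch statistics from a single sample
bn = nn.BatchNorm1d(256)
bn.train()
try:
    bn(x)
except ValueError as e:
    print(f"BatchNorm1d fails at batch size 1: {e}")

# GroupNorm and LayerNorm normalize within each sample, so batch size is irrelevant
gn = nn.GroupNorm(num_groups=32, num_channels=256)
ln = nn.LayerNorm(256)
print(gn(x).shape, ln(x).shape)  # both torch.Size([1, 256])
```

Because the per-sample alternatives never touch batch statistics, they also behave identically in train and eval modes, removing an entire class of mode-switching bugs.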

Problem 2: Forgetting to switch modes

This causes incorrect inference results:

# Wrong - still in training mode from training loop
predictions = model(test_data)

# Correct
model.eval()
with torch.no_grad():
    predictions = model(test_data)

Problem 3: Inspecting batch norm statistics

Debug issues by examining running statistics:

def inspect_batchnorm(model):
    for name, module in model.named_modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            print(f"\n{name}:")
            print(f"  Running mean: {module.running_mean[:5]}")
            print(f"  Running var: {module.running_var[:5]}")
            print(f"  Num batches tracked: {module.num_batches_tracked}")
            print(f"  Gamma (scale): {module.weight[:5]}")
            print(f"  Beta (shift): {module.bias[:5]}")

model = ConvNetWithBN()
# Train for a few iterations...
inspect_batchnorm(model)

Problem 4: Distributed training

In distributed settings, use nn.SyncBatchNorm to synchronize statistics across GPUs:

model = ConvNetWithBN()
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(model)
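convert_sync_batchnorm recursively swaps every BatchNorm module for a SyncBatchNorm; the conversion itself needs no process group, so you can verify it on a toy model (the layer sizes are arbitrary):

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(model[1]).__name__)  # SyncBatchNorm
```

During distributed training, each SyncBatchNorm layer computes statistics over the combined batch across all GPUs, which restores reliable statistics when the per-GPU batch is small.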

Batch normalization is powerful but requires understanding its behavior. Place it after convolutions/linear layers, use adequate batch sizes, always call model.eval() during inference, and consider alternatives for small-batch scenarios. Master these principles, and you’ll train faster, more stable networks.
