How to Implement Batch Normalization in PyTorch
Key Insights
- Batch normalization normalizes layer inputs using batch statistics during training but switches to running statistics during inference—forgetting to call model.eval() is one of the most common bugs in production systems.
- Place batch normalization after convolutional or linear layers but before activation functions for best results, though the after-activation placement can work in specific architectures.
- Batch size matters significantly: batch normalization degrades with small batches (< 8 samples) because batch statistics become unreliable, making layer normalization or group normalization better alternatives for small-batch scenarios.
Introduction to Batch Normalization
Batch normalization revolutionized deep learning training when introduced in 2015. It addresses internal covariate shift—the phenomenon where the distribution of layer inputs changes during training as previous layers’ parameters update. This instability forces you to use small learning rates and careful initialization.
Batch normalization normalizes each feature to have zero mean and unit variance across a mini-batch, then applies learnable scale (γ) and shift (β) parameters:
y = γ * (x - μ_B) / √(σ²_B + ε) + β
Where μ_B is the batch mean, σ²_B is the batch variance, and ε is a small constant for numerical stability.
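To make the formula concrete, here is a tiny worked example in plain Python with γ = 1 and β = 0, so only the normalization step is shown:

```python
import math

# One feature observed across a mini-batch of 4 samples
x = [1.0, 2.0, 3.0, 4.0]

mu = sum(x) / len(x)                          # batch mean: 2.5
var = sum((v - mu) ** 2 for v in x) / len(x)  # biased batch variance: 1.25
eps = 1e-5                                    # numerical-stability constant

y = [(v - mu) / math.sqrt(var + eps) for v in x]
# y now has zero mean and (approximately) unit variance
```

With γ and β left at their initial values of 1 and 0, this is exactly the normalization a batch-norm layer performs in training mode; the learnable parameters then let the network rescale or undo the normalization if that helps.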
The benefits are tangible: you can use learning rates 10x higher, training converges faster, and networks become less sensitive to initialization. It also provides a regularization effect, sometimes eliminating the need for dropout.
Here’s a simple comparison showing the impact:
import torch
import torch.nn as nn

# Network without batch norm
class NetWithoutBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.layers(x)

# Network with batch norm
class NetWithBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.layers(x)

# Training with a higher learning rate becomes stable with BN
model_with_bn = NetWithBN()
optimizer = torch.optim.SGD(model_with_bn.parameters(), lr=0.1)  # 10x higher LR
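As a minimal illustration, the sketch below runs one optimizer step at the higher learning rate; random tensors stand in for a real dataset, so the loss value itself is meaningless:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same structure as NetWithBN above, defined inline so the snippet is self-contained
model = nn.Sequential(
    nn.Linear(784, 256), nn.BatchNorm1d(256), nn.ReLU(),
    nn.Linear(256, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

x = torch.randn(32, 784)               # stand-in for a batch of flattened images
targets = torch.randint(0, 10, (32,))  # stand-in labels

model.train()                          # batch statistics are used and updated
optimizer.zero_grad()
loss = criterion(model(x), targets)
loss.backward()
optimizer.step()
```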
Using PyTorch’s Built-in BatchNorm Layers
PyTorch provides three batch normalization variants:
- nn.BatchNorm1d: For 2D inputs (batch_size, num_features) or 3D inputs (batch_size, num_features, length). Use for fully connected layers or 1D convolutions.
- nn.BatchNorm2d: For 4D inputs (batch_size, channels, height, width). Use for 2D convolutions in CNNs.
- nn.BatchNorm3d: For 5D inputs (batch_size, channels, depth, height, width). Use for 3D convolutions in video or medical imaging.
Key parameters you need to understand:
- num_features: Number of features/channels to normalize (required)
- eps: Small constant for numerical stability (default: 1e-5)
- momentum: Factor for running statistics update (default: 0.1)
- affine: Whether to learn γ and β parameters (default: True)
- track_running_stats: Whether to track running mean/variance (default: True)
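A quick look at how these arguments show up as module state (the tensor values below are just the defaults at initialization):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=4, eps=1e-5, momentum=0.1, affine=True)

print(bn.weight)        # gamma: learnable, initialized to ones
print(bn.bias)          # beta: learnable, initialized to zeros
print(bn.running_mean)  # buffer, starts at zeros
print(bn.running_var)   # buffer, starts at ones

# With affine=False there are no learnable gamma/beta at all
bn_plain = nn.BatchNorm1d(4, affine=False)
print(bn_plain.weight is None)  # True
```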
Here’s a practical CNN implementation:
class ConvNetWithBN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            # First conv block
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),  # num_features = number of channels
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Second conv block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Third conv block
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 8 * 8, 512),  # assumes 32x32 inputs
            nn.BatchNorm1d(512),  # Use BatchNorm1d for fully connected layers
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
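A small sanity check on a dummy batch: BatchNorm2d normalizes each channel over the batch and spatial dimensions and leaves the tensor shape unchanged.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64)
bn.train()

x = torch.randn(8, 64, 16, 16)
y = bn(x)

print(y.shape)  # unchanged: torch.Size([8, 64, 16, 16])
# In training mode each channel of the output has (near-)zero mean
per_channel_mean = y.mean(dim=(0, 2, 3))
print(per_channel_mean.abs().max())  # close to 0
```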
Implementing Batch Normalization from Scratch
Understanding the internals helps debug issues and customize behavior. Here’s a complete implementation:
class CustomBatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running statistics (buffers, not parameters: saved with the
        # model's state dict but never updated by the optimizer)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def forward(self, x):
        # x shape: (batch_size, num_features, height, width)
        if self.training:
            # Calculate batch statistics over batch, height and width
            batch_mean = x.mean(dim=(0, 2, 3))
            batch_var = x.var(dim=(0, 2, 3), unbiased=False)
            # Update running statistics
            with torch.no_grad():
                n = x.numel() / x.size(1)  # elements per channel
                self.running_mean = (1 - self.momentum) * self.running_mean + \
                                    self.momentum * batch_mean
                # Like nn.BatchNorm2d, track the unbiased variance
                self.running_var = (1 - self.momentum) * self.running_var + \
                                   self.momentum * batch_var * n / (n - 1)
                self.num_batches_tracked += 1
            # Use batch statistics for normalization
            mean = batch_mean
            var = batch_var
        else:
            # Use running statistics during evaluation
            mean = self.running_mean
            var = self.running_var
        # Normalize: reshape for broadcasting over (N, C, H, W)
        mean = mean.view(1, -1, 1, 1)
        var = var.view(1, -1, 1, 1)
        gamma = self.gamma.view(1, -1, 1, 1)
        beta = self.beta.view(1, -1, 1, 1)
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        return gamma * x_normalized + beta
This implementation shows the critical distinction between training and evaluation modes. During training, we normalize using the current batch’s statistics and update the running averages. During evaluation, we use the accumulated running statistics.
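You can verify the training-mode arithmetic against PyTorch's built-in layer; a minimal check with a freshly initialized layer, where γ = 1 and β = 0:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)

bn = nn.BatchNorm2d(3)
bn.train()
ref = bn(x)

# Manual normalization with the same (biased) batch statistics
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + bn.eps)

print(torch.allclose(ref, manual, atol=1e-6))  # True
```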
Batch Normalization Placement and Best Practices
The original paper placed batch normalization before activation functions, and this remains the recommended approach:
Conv/Linear → BatchNorm → Activation → Dropout (if used)
This ordering allows batch norm to normalize the pre-activation distribution. However, some modern architectures place it after activation with good results—experiment with your specific use case.
Here’s a ResNet-style residual block showing proper placement:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = torch.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = torch.relu(out)
        return out
Important considerations:
- Batch normalization provides regularization, so you may need less or no dropout
- With batch norm, weight decay becomes more important for regularization
- Use batch sizes of at least 16-32 for stable statistics; smaller batches make batch norm unreliable
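Because weight decay pulls γ and β toward zero, many training recipes exempt batch-norm parameters (and biases) from decay. One way to sketch this with optimizer parameter groups (the helper name here is illustrative, not a PyTorch API):

```python
import torch
import torch.nn as nn

def split_decay_params(model):
    """Separate parameters that should get weight decay from those
    (batch-norm gamma/beta and all biases) that usually should not."""
    decay, no_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            is_bn = isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
            (no_decay if is_bn or name == "bias" else decay).append(param)
    return decay, no_decay

model = nn.Sequential(
    nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 4)
)
decay, no_decay = split_decay_params(model)

optimizer = torch.optim.SGD(
    [{"params": decay, "weight_decay": 1e-4},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=0.1,
)
```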
Training vs. Evaluation Mode
This is where most bugs occur. Batch normalization behaves completely differently in training and evaluation:
model = ConvNetWithBN()

# Training mode: uses batch statistics
model.train()
batch = torch.randn(32, 3, 32, 32)
output_train = model(batch)

# Evaluation mode: uses running statistics
model.eval()
single_sample = torch.randn(1, 3, 32, 32)
output_eval = model(single_sample)

# Proper inference setup
def inference(model, data_loader):
    model.eval()  # Critical!
    predictions = []
    with torch.no_grad():  # Disable gradient computation
        for batch in data_loader:
            output = model(batch)
            predictions.append(output)
    return torch.cat(predictions)
# Demonstrating the difference between modes
model = ConvNetWithBN()
test_input = torch.randn(16, 3, 32, 32)

model.train()
out_full = model(test_input)
out_half = model(test_input[:8])
# Same samples, different batch-mates -> different batch statistics
print(f"Training mode - outputs depend on the batch: {not torch.allclose(out_full[:8], out_half)}")

model.eval()
out3 = model(test_input)
out4 = model(test_input)
print(f"Eval mode - outputs identical: {torch.allclose(out3, out4)}")
In training mode, a sample's output depends on the other samples in its mini-batch, because the normalization statistics are computed from the whole batch. In evaluation mode, outputs are deterministic because the fixed running statistics are used.
Common Pitfalls and Debugging
Problem 1: Small batch sizes
With batch size 1, batch statistics are meaningless. Use group normalization or layer normalization instead:
# Instead of BatchNorm2d with small batches
nn.GroupNorm(num_groups=32, num_channels=256) # Divides channels into groups
# or
nn.LayerNorm([256, 32, 32]) # Normalizes over channels, height, width
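A quick check that these alternatives really are batch-size independent; note that BatchNorm1d actually raises an error for a single training sample, since the per-feature variance of one value is undefined:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 256, 8, 8)  # batch size 1

gn = nn.GroupNorm(num_groups=32, num_channels=256)
print(gn(x).shape)  # works fine: torch.Size([1, 256, 8, 8])

bn = nn.BatchNorm1d(256)
bn.train()
try:
    bn(torch.randn(1, 256))  # variance of one sample per feature is undefined
except ValueError as e:
    print("BatchNorm1d refused:", e)
```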
Problem 2: Forgetting to switch modes
This causes incorrect inference results:
# Wrong - still in training mode from the training loop
predictions = model(test_data)

# Correct
model.eval()
with torch.no_grad():
    predictions = model(test_data)
Problem 3: Inspecting batch norm statistics
Debug issues by examining running statistics:
def inspect_batchnorm(model):
    for name, module in model.named_modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            print(f"\n{name}:")
            print(f"  Running mean: {module.running_mean[:5]}")
            print(f"  Running var: {module.running_var[:5]}")
            print(f"  Num batches tracked: {module.num_batches_tracked}")
            print(f"  Gamma (scale): {module.weight[:5]}")
            print(f"  Beta (shift): {module.bias[:5]}")

model = ConvNetWithBN()
# Train for a few iterations...
inspect_batchnorm(model)
Problem 4: Distributed training
In distributed settings, use nn.SyncBatchNorm to synchronize statistics across GPUs:
model = ConvNetWithBN()
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(model)
Batch normalization is powerful but requires understanding its behavior. Place it after convolutions/linear layers, use adequate batch sizes, always call model.eval() during inference, and consider alternatives for small-batch scenarios. Master these principles, and you’ll train faster, more stable networks.