How to Implement Custom Loss Functions in PyTorch

Key Insights

  • Custom loss functions in PyTorch can be implemented as simple Python functions for quick prototyping or as nn.Module subclasses for production code with learnable parameters
  • Always verify gradient computation with torch.autograd.gradcheck() to catch numerical instabilities and implementation errors before training
  • Vectorize operations and avoid Python loops in custom losses—a poorly optimized loss function can become the bottleneck in your training pipeline

Introduction to Loss Functions in PyTorch

Loss functions quantify how wrong your model’s predictions are, providing the optimization signal that drives learning. PyTorch ships with standard losses like nn.CrossEntropyLoss(), nn.MSELoss(), and nn.BCELoss(), which cover most common scenarios. But real-world problems often demand specialized loss functions: you might need domain-specific penalties, multi-task objectives, or custom weighting schemes that standard implementations don’t provide.

Here’s the difference between using a built-in loss and implementing your own:

import torch
import torch.nn as nn

# Built-in approach
predictions = torch.randn(32, 1)
targets = torch.randn(32, 1)
loss_fn = nn.MSELoss()
loss = loss_fn(predictions, targets)

# Manual implementation - functionally identical
manual_loss = ((predictions - targets) ** 2).mean()

print(f"Built-in: {loss.item():.4f}, Manual: {manual_loss.item():.4f}")

Both approaches compute the same value, but custom implementations let you inject domain knowledge, handle edge cases differently, or combine multiple objectives in ways that reflect your problem’s unique requirements.

Anatomy of a PyTorch Loss Function

A PyTorch loss function is just a differentiable operation that takes predictions and targets as input and returns a scalar (or tensor, depending on reduction). The critical requirement: PyTorch must be able to compute gradients through it via automatic differentiation.
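To see what the differentiability requirement means in practice, here's a small sketch contrasting a traceable computation with one that accidentally cuts the autograd graph (a common bug when converting to NumPy or Python floats mid-loss):

```python
import torch

preds = torch.randn(4, requires_grad=True)
targets = torch.randn(4)

# Differentiable: built entirely from tensor ops, so autograd can trace it
loss = ((preds - targets) ** 2).mean()
loss.backward()
print(preds.grad is not None)  # True - gradients reached the inputs

# Broken: .detach() (like .item() or a NumPy round-trip) cuts the graph,
# so calling backward() on this result would raise an error
broken = ((preds.detach() - targets) ** 2).mean()
print(broken.requires_grad)  # False - no gradient path back to preds
```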

Let’s examine the structure using MSE loss:

def mse_loss_detailed(predictions, targets, reduction='mean'):
    """
    predictions: (batch_size, ...) - model outputs
    targets: (batch_size, ...) - ground truth values
    reduction: 'mean' | 'sum' | 'none'
    """
    # Element-wise squared difference
    squared_diff = (predictions - targets) ** 2  # shape: (batch_size, ...)
    
    if reduction == 'none':
        return squared_diff  # Return per-sample losses
    elif reduction == 'sum':
        return squared_diff.sum()  # Single scalar
    else:  # 'mean'
        return squared_diff.mean()  # Single scalar (default)

# Usage
preds = torch.randn(16, 10, requires_grad=True)
targs = torch.randn(16, 10)

loss = mse_loss_detailed(preds, targs)
loss.backward()  # Gradients flow back through the operations
print(f"Loss: {loss.item():.4f}, Gradient shape: {preds.grad.shape}")

The reduction parameter controls output shape: 'none' returns per-sample losses useful for weighted sampling, while 'mean' or 'sum' produce the scalar needed for backpropagation.
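As a quick illustration of why per-sample losses are useful, here's a sketch of a hypothetical weighting scheme that re-weights each sample's loss before reducing (the "double the hardest half" rule is purely illustrative):

```python
import torch

preds = torch.randn(16, 10, requires_grad=True)
targs = torch.randn(16, 10)

# reduction='none' territory: keep one loss value per sample
per_sample = ((preds - targs) ** 2).mean(dim=1)  # shape: (16,)

# Hypothetical scheme: double the weight of the hardest half of the batch
weights = 1.0 + (per_sample > per_sample.median()).float()

loss = (per_sample * weights).mean()  # reduce to the scalar backward() needs
loss.backward()
```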

Creating a Basic Custom Loss Function

Start with a simple function, then promote it to a class when you need state or parameters. Let’s implement a weighted MSE loss where certain samples matter more:

# Approach 1: Simple function
def weighted_mse_loss(predictions, targets, weights):
    """Apply per-sample weights to MSE loss."""
    squared_diff = (predictions - targets) ** 2
    weighted_diff = squared_diff * weights.unsqueeze(-1)
    return weighted_diff.mean()

# Usage
preds = torch.randn(8, 5)
targs = torch.randn(8, 5)
weights = torch.tensor([1.0, 1.0, 2.0, 2.0, 0.5, 0.5, 1.5, 1.5])

loss = weighted_mse_loss(preds, targs, weights)

For production code, wrap it in nn.Module to integrate cleanly with PyTorch’s ecosystem:

class WeightedMSELoss(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction
    
    def forward(self, predictions, targets, weights):
        squared_diff = (predictions - targets) ** 2
        weighted_diff = squared_diff * weights.unsqueeze(-1)
        
        if self.reduction == 'mean':
            return weighted_diff.mean()
        elif self.reduction == 'sum':
            return weighted_diff.sum()
        else:
            return weighted_diff

# Usage
loss_fn = WeightedMSELoss()
loss = loss_fn(preds, targs, weights)

The nn.Module approach provides better encapsulation and makes your loss function a first-class citizen in PyTorch’s architecture.
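One thing the class form enables that a plain function can't: learnable parameters, as mentioned in the key insights. Here's a hedged sketch of what that might look like; the log-scale trick is illustrative (loosely inspired by uncertainty-weighting schemes), not a drop-in recipe:

```python
import torch
import torch.nn as nn

class LearnableScaleMSE(nn.Module):
    """Illustrative loss with a learnable scale (hypothetical design)."""
    def __init__(self):
        super().__init__()
        # Registered as an nn.Parameter, so optimizers can update it
        self.log_scale = nn.Parameter(torch.tensor(0.0))

    def forward(self, predictions, targets):
        mse = ((predictions - targets) ** 2).mean()
        # exp(-log_scale) scales the loss; adding log_scale back penalizes
        # the trivial solution of scaling everything toward zero
        return torch.exp(-self.log_scale) * mse + self.log_scale

model = nn.Linear(10, 1)
loss_fn = LearnableScaleMSE()
# Optimize the loss's parameter jointly with the model's weights
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(loss_fn.parameters()), lr=1e-3
)
```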

Advanced Custom Loss Implementations

Real applications often require combining multiple objectives or implementing research paper losses. Here’s a multi-task loss for simultaneous classification and regression:

class MultiTaskLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
    
    def forward(self, class_preds, class_targets, reg_preds, reg_targets):
        cls_loss = self.ce_loss(class_preds, class_targets)
        reg_loss = self.mse_loss(reg_preds, reg_targets)
        return self.alpha * cls_loss + (1 - self.alpha) * reg_loss

Focal loss, designed to handle class imbalance by down-weighting easy examples, demonstrates a more complex implementation:

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        """
        inputs: (N, C) - raw logits
        targets: (N,) - class indices
        """
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        p_t = torch.exp(-ce_loss)  # Probability of true class
        focal_weight = (1 - p_t) ** self.gamma
        focal_loss = self.alpha * focal_weight * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# Example usage
focal_loss_fn = FocalLoss(alpha=0.25, gamma=2.0)
logits = torch.randn(32, 10)
labels = torch.randint(0, 10, (32,))
loss = focal_loss_fn(logits, labels)

Testing and Debugging Custom Losses

Never deploy a custom loss without verification. Use gradcheck to numerically validate gradient computation:

from torch.autograd import gradcheck

def test_custom_loss():
    # Create loss function
    loss_fn = WeightedMSELoss()
    
    # Use double precision for numerical stability in gradcheck
    preds = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    targs = torch.randn(4, 3, dtype=torch.double)
    weights = torch.ones(4, dtype=torch.double)
    
    # gradcheck returns True if analytical gradients match numerical
    test_passed = gradcheck(
        lambda x: loss_fn(x, targs, weights),
        preds,
        eps=1e-6,
        atol=1e-4
    )
    
    print(f"Gradient check passed: {test_passed}")

test_custom_loss()

Also test edge cases: zero predictions, identical inputs/targets, extreme values, and empty batches. Add assertions to catch invalid inputs:

class RobustCustomLoss(nn.Module):
    def forward(self, predictions, targets):
        assert predictions.shape == targets.shape, \
            f"Shape mismatch: {predictions.shape} vs {targets.shape}"
        assert not torch.isnan(predictions).any(), "NaN in predictions"
        return ((predictions - targets) ** 2).mean()  # the actual loss computation

Integration with Training Loops

Custom losses integrate seamlessly into standard training code:

import torch.optim as optim

# Setup
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
loss_fn = WeightedMSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Dummy data
train_data = [(torch.randn(32, 10), torch.randn(32, 1), torch.ones(32)) 
              for _ in range(100)]

# Training loop
for epoch in range(10):
    epoch_loss = 0.0
    for batch_x, batch_y, batch_weights in train_data:
        optimizer.zero_grad()
        
        predictions = model(batch_x)
        loss = loss_fn(predictions, batch_y, batch_weights)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(train_data):.4f}")

Performance Considerations

Inefficient loss functions bottleneck training. Always vectorize operations and avoid Python loops:

# SLOW: Python loop
def slow_weighted_loss(preds, targets, weights):
    loss = 0.0
    for i in range(len(preds)):
        loss += weights[i] * ((preds[i] - targets[i]) ** 2).sum()
    return loss / len(preds)

# FAST: Vectorized
def fast_weighted_loss(preds, targets, weights):
    squared_diff = (preds - targets) ** 2
    return (squared_diff * weights.unsqueeze(-1)).mean()

# Benchmark (falls back to CPU if no GPU is available)
device = "cuda" if torch.cuda.is_available() else "cpu"
preds = torch.randn(1000, 100, device=device)
targets = torch.randn(1000, 100, device=device)
weights = torch.rand(1000, device=device)

import time
start = time.time()
for _ in range(100):
    slow_weighted_loss(preds, targets, weights)
if device == "cuda":
    torch.cuda.synchronize()  # wait for queued GPU kernels before reading the clock
print(f"Slow: {time.time() - start:.3f}s")

start = time.time()
for _ in range(100):
    fast_weighted_loss(preds, targets, weights)
if device == "cuda":
    torch.cuda.synchronize()
print(f"Fast: {time.time() - start:.3f}s")

The vectorized version typically runs 10-100x faster. Additional tips: use in-place operations (+=, *=) only when they don't overwrite tensors the backward pass still needs, leverage PyTorch's functional API for common operations, and profile with torch.autograd.profiler to identify bottlenecks.
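For that last profiling tip, a minimal torch.autograd.profiler sketch looks like this (CPU-only here to keep it portable):

```python
import torch
from torch.autograd import profiler

preds = torch.randn(1000, 100)
targets = torch.randn(1000, 100)
weights = torch.rand(1000)

# Record operator-level timings for the loss computation
with profiler.profile() as prof:
    for _ in range(10):
        loss = ((preds - targets) ** 2 * weights.unsqueeze(-1)).mean()

# Sort by total CPU time to see which ops dominate
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```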

Custom loss functions unlock PyTorch’s full potential for specialized problems. Start simple with functions, graduate to nn.Module classes for complex logic, always validate gradients, and optimize for performance. Your loss function is the signal that guides learning—make it count.
