How to Use TensorBoard with PyTorch

Key Insights

  • TensorBoard’s SummaryWriter integrates seamlessly with PyTorch to track metrics, visualize architectures, and compare experiments without writing custom plotting code
  • Proper logging hygiene—including consistent naming conventions, regular writer flushing, and organized run directories—prevents data loss and makes experiment comparison straightforward
  • Beyond basic scalar logging, TensorBoard’s image logging and embedding projector help debug computer vision and NLP models by showing what the model predicts on real inputs and how it clusters learned features

Introduction & Setup

TensorBoard started as TensorFlow’s visualization toolkit but has become the de facto standard for monitoring deep learning experiments across frameworks. For PyTorch developers, it provides real-time training visualization, model architecture inspection, and experiment comparison without maintaining separate plotting scripts.

The value proposition is simple: instead of printing metrics to the console or writing custom matplotlib code, you log everything to TensorBoard and get interactive, web-based visualizations automatically. This becomes critical when running multiple experiments with different hyperparameters or debugging why your model isn’t converging.

Installation is straightforward:

pip install torch torchvision tensorboard

The basic imports you’ll need:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
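The training and evaluation loops in this article read from train_loader, val_loader, and test_loader, but the code never defines them. The sketch below assumes MNIST-shaped inputs (1x28x28 grayscale images, 10 classes) and builds all three loaders from random tensors, so the examples run without a download. The dataset sizes and the helper name make_fake_split are illustrative choices, not part of any library.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data shaped like MNIST: 1x28x28 grayscale images, 10 classes.
# Random tensors stand in for the real dataset; the sample counts are arbitrary.
def make_fake_split(n_samples, seed):
    gen = torch.Generator().manual_seed(seed)
    images = torch.rand(n_samples, 1, 28, 28, generator=gen)
    labels = torch.randint(0, 10, (n_samples,), generator=gen)
    return TensorDataset(images, labels)

train_loader = DataLoader(make_fake_split(2048, seed=0), batch_size=32, shuffle=True)
val_loader = DataLoader(make_fake_split(512, seed=1), batch_size=32)
test_loader = DataLoader(make_fake_split(512, seed=2), batch_size=32)

# For real MNIST, replace each split with something like:
# datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
```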

Launch TensorBoard from your terminal to view logged data:

tensorboard --logdir=runs

This starts a local web server (typically at localhost:6006) that displays all experiments logged to the runs directory.
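You can check the installation end to end before adding TensorBoard to a real training loop. This sketch logs one value to a throwaway runs/smoke_test directory (a name chosen for illustration) and confirms that an event file appeared. The sanity/value tag should then show up in the Scalars tab; if TensorBoard is already running, refresh the page.

```python
import glob
import os

from torch.utils.tensorboard import SummaryWriter

# Write one scalar to a throwaway run directory to confirm logging works
log_dir = os.path.join('runs', 'smoke_test')
writer = SummaryWriter(log_dir)
writer.add_scalar('sanity/value', 1.0, 0)
writer.close()  # close() flushes pending events to disk

# SummaryWriter names its files events.out.tfevents.<timestamp>.<hostname>...
event_files = glob.glob(os.path.join(log_dir, 'events.out.tfevents.*'))
print(f'Found {len(event_files)} event file(s) in {log_dir}')
```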

Logging Scalars and Metrics

Scalar logging is your bread and butter for tracking training progress. The SummaryWriter class handles all logging operations, and you should create one instance per experiment run.
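If you would rather not build run names by hand, SummaryWriter can generate a unique directory for each run. When log_dir is omitted, the writer creates a directory under runs/ named after the current date-time and hostname, and the comment argument appends a suffix, such as the hyperparameters. The lr and batch_size values below are illustrative.

```python
from torch.utils.tensorboard import SummaryWriter

lr, batch_size = 0.001, 32

# With no log_dir, SummaryWriter creates runs/<date-time>_<hostname><comment>,
# so every run gets its own directory; comment only appends a suffix.
writer = SummaryWriter(comment=f'_lr_{lr}_bs_{batch_size}')
print(writer.get_logdir())  # something like runs/<date-time>_<hostname>_lr_0.001_bs_32
writer.close()
```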

Here’s a complete training loop with proper metric logging:

# Create writer with a descriptive run name built from the hyperparameters
lr, batch_size = 0.001, 32
writer = SummaryWriter(f'runs/experiment_lr_{lr}_batch_{batch_size}')

# Simple neural network for demonstration
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 10)
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
n_epochs = 10
global_step = 0

for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data.view(data.size(0), -1))
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        # Accumulate metrics
        running_loss += loss.item()
        predicted = output.argmax(dim=1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        # Log every 100 batches; global_step counts batches across all epochs
        if global_step % 100 == 0:
            writer.add_scalar('Loss/train', loss.item(), global_step)
        global_step += 1
    
    # Log epoch-level metrics
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100 * correct / total
    
    writer.add_scalar('Loss/train_epoch', epoch_loss, epoch)
    writer.add_scalar('Accuracy/train', epoch_acc, epoch)
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data.view(data.size(0), -1))
            loss = criterion(output, target)
            val_loss += loss.item()
            predicted = output.argmax(dim=1)
            val_total += target.size(0)
            val_correct += (predicted == target).sum().item()
    
    val_loss /= len(val_loader)
    val_acc = 100 * val_correct / val_total
    
    writer.add_scalar('Loss/validation', val_loss, epoch)
    writer.add_scalar('Accuracy/validation', val_acc, epoch)
    
    print(f'Epoch {epoch}: Train Loss={epoch_loss:.4f}, Val Loss={val_loss:.4f}')

writer.close()

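The hierarchical naming (Loss/train, Loss/validation) groups related tags into the same section of TensorBoard’s Scalars dashboard, but each tag still gets its own chart. To overlay curves on one chart, log them together with add_scalars(). Alternatively, log the same tag from several runs, and TensorBoard will draw them in one chart with a different color for each run.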
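Here is a sketch that draws training and validation loss on a single chart with add_scalars(). The method takes a main tag, a dict mapping series names to values, and a step. The loss values are synthetic placeholders, and runs/overlay_demo is an illustrative directory name. One caveat: in current PyTorch versions, add_scalars writes each series to its own subdirectory (for example runs/overlay_demo/Loss_train), so each series also appears as a separate entry in the run selector.

```python
import glob
import math
import os

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(os.path.join('runs', 'overlay_demo'))

# Illustrative numbers only: decaying curves standing in for real losses
for epoch in range(10):
    train_loss = math.exp(-0.3 * epoch)
    val_loss = math.exp(-0.25 * epoch) + 0.05
    # One call logs both values; TensorBoard draws them on a single 'Loss' chart
    writer.add_scalars('Loss', {'train': train_loss, 'validation': val_loss}, epoch)

writer.close()

# Each series is written under its own subdirectory of the run
event_files = glob.glob(os.path.join('runs', 'overlay_demo', '**', 'events.out.tfevents.*'),
                        recursive=True)
```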

Visualizing Model Architecture

Understanding your model’s computational graph helps debug architecture issues and confirm that tensors flow through the layers you expect. TensorBoard traces one forward pass on an example input and visualizes the resulting graph:

# Define a simple CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = torch.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

model = SimpleCNN()
writer = SummaryWriter('runs/cnn_architecture')

# Create dummy input matching your data shape
dummy_input = torch.randn(1, 1, 28, 28)

# Log the model graph
writer.add_graph(model, dummy_input)
writer.close()

The graph visualization shows layer connections, tensor shapes at each stage, and operation types. This is invaluable when debugging dimension mismatches or verifying that skip connections are properly configured.
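Dimension mismatches usually surface at the first fully connected layer. The 9216 in SimpleCNN’s fc1 comes from spatial arithmetic. An unpadded 3x3 convolution with stride 1 shrinks each side by 2, and a 2x2 max pool halves it: 28 -> 26 -> 24 -> 12, giving 64 x 12 x 12 = 9216 features. The quick check below rebuilds only the convolutional part as a standalone copy (not the model object itself) to confirm that number.

```python
import torch
import torch.nn as nn

# Same convolution settings as SimpleCNN: 3x3 kernels, stride 1, no padding
conv_stack = nn.Sequential(
    nn.Conv2d(1, 32, 3, 1),   # 28x28 -> 26x26
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, 1),  # 26x26 -> 24x24
    nn.ReLU(),
    nn.MaxPool2d(2),          # 24x24 -> 12x12
)

with torch.no_grad():
    out = conv_stack(torch.randn(1, 1, 28, 28))

flat_features = torch.flatten(out, 1).shape[1]
print(out.shape, flat_features)  # torch.Size([1, 64, 12, 12]) 9216
```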

Tracking Hyperparameters and Results

Comparing experiments with different hyperparameters is where TensorBoard truly shines. The add_hparams() method creates a structured comparison view:

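def train_with_hparams(hparams):
    """Train SimpleCNN with the given hyperparameters; reuses the global train/val datasets."""
    model = SimpleCNN()
    model.dropout2.p = hparams['dropout']  # nn.Dropout reads p on every forward pass
    optimizer = optim.Adam(model.parameters(), lr=hparams['lr'])
    criterion = nn.NLLLoss()  # SimpleCNN already returns log-probabilities
    loader = DataLoader(train_loader.dataset, batch_size=hparams['batch_size'], shuffle=True)
    
    for epoch in range(hparams['epochs']):
        model.train()
        for data, target in loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
    
    # Evaluate the final model on the validation set
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            total_loss += criterion(output, target).item()
            correct += (output.argmax(dim=1) == target).sum().item()
            total += target.size(0)
    
    # Metric tags use the 'hparam/' prefix, as in the PyTorch add_hparams docs example
    return {
        'hparam/accuracy': 100 * correct / total,
        'hparam/loss': total_loss / len(val_loader)
    }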

# Run multiple experiments
hparam_configs = [
    {'lr': 0.001, 'batch_size': 32, 'dropout': 0.5, 'epochs': 10},
    {'lr': 0.0001, 'batch_size': 64, 'dropout': 0.3, 'epochs': 10},
    {'lr': 0.01, 'batch_size': 16, 'dropout': 0.5, 'epochs': 10},
]

for i, hparams in enumerate(hparam_configs):
    writer = SummaryWriter(f'runs/hparam_search/run_{i}')
    
    # Train and get metrics
    metrics = train_with_hparams(hparams)
    
    # Log hyperparameters and results
    writer.add_hparams(
        hparams,
        metrics
    )
    writer.close()

TensorBoard’s HPARAMS tab displays a parallel coordinates plot and table view, making it easy to identify which hyperparameter combinations perform best.

Advanced Visualizations

Beyond scalars, TensorBoard supports rich visualizations that accelerate debugging:

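from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

writer = SummaryWriter('runs/advanced_viz')

# Log a batch of sample images alongside their predicted labels
model.eval()
with torch.no_grad():
    data, target = next(iter(test_loader))
    output = model(data)
    predicted = output.argmax(dim=1)
    
    # add_images takes an (N, C, H, W) batch and cannot draw labels, so log them as text
    writer.add_images('samples', data[:16], 0)
    pairs = [f'{p}/{t}' for p, t in zip(predicted[:16].tolist(), target[:16].tolist())]
    writer.add_text('predictions', 'predicted/true: ' + ', '.join(pairs), 0)
    
    # Log a confusion matrix over the full test set as a matplotlib figure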
    
    all_preds = []
    all_targets = []
    
    for data, target in test_loader:
        output = model(data)
        _, preds = torch.max(output, 1)
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(target.cpu().numpy())
    
    cm = confusion_matrix(all_targets, all_preds)
    
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(cm, cmap='Blues')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    
    writer.add_figure('confusion_matrix', fig, 0)

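# Log embeddings for visualization (capped at 1,000 test samples to keep the projector responsive)
max_points = 1000
features, labels, images = [], [], []

with torch.no_grad():
    for data, target in test_loader:
        # fc1 activations (penultimate layer), mirroring SimpleCNN.forward;
        # dropout is skipped because it is a no-op in eval mode
        x = torch.relu(model.conv1(data))
        x = torch.relu(model.conv2(x))
        x = torch.flatten(torch.max_pool2d(x, 2), 1)
        features.append(torch.relu(model.fc1(x)))
        labels.append(target)
        images.append(data)  # thumbnails displayed next to each point
        if sum(len(f) for f in features) >= max_points:
            break

features = torch.cat(features)[:max_points]
labels = torch.cat(labels)[:max_points]
images = torch.cat(images)[:max_points]

# metadata takes a list of labels; label_img needs one image per row of features
writer.add_embedding(features, metadata=labels.tolist(), label_img=images)
writer.close()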

The embedding projector reduces high-dimensional features to 2D or 3D with PCA, t-SNE, or UMAP. This shows whether the model groups samples of the same class together, a sign that it is learning meaningful representations.

Best Practices & Tips

Effective TensorBoard usage requires discipline:

from contextlib import contextmanager
from datetime import datetime
from pathlib import Path

from torch.utils.tensorboard import SummaryWriter

@contextmanager
def create_writer(experiment_name):
    """Context manager that always flushes and closes its SummaryWriter."""
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_dir = Path('runs') / experiment_name / timestamp
    
    writer = SummaryWriter(str(log_dir))
    
    try:
        yield writer
    finally:
        # Runs even if training raises, so buffered events still reach disk
        writer.flush()
        writer.close()

# Usage
with create_writer('my_experiment') as writer:
    for epoch in range(n_epochs):
        epoch_loss = train_one_epoch()  # hypothetical helper wrapping your training loop
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        
        # Flush periodically during long runs
        if epoch % 10 == 0:
            writer.flush()

Directory organization matters: Use hierarchical structures like runs/project_name/experiment_type/timestamp to keep experiments organized. This makes it easy to compare related runs and clean up old experiments.

Naming conventions: Use consistent prefixes (Loss/, Accuracy/, Gradient/) to group related metrics. Avoid spaces in names—use underscores instead.
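To illustrate the Gradient/ prefix, this sketch logs the L2 norm of each parameter’s gradient after a backward pass. The toy model, random batch, and runs/gradient_norms_demo directory are assumptions made for the example. In practice you would log these values every N steps inside the training loop, right after loss.backward().

```python
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
writer = SummaryWriter('runs/gradient_norms_demo')

# One illustrative backward pass on random data
loss = nn.CrossEntropyLoss()(model(torch.randn(32, 784)), torch.randint(0, 10, (32,)))
loss.backward()

# The shared Gradient/ prefix groups every layer's norm into one dashboard section.
# Parameter names like '0.weight' contain no spaces, so they work as tag suffixes.
logged_tags = []
for name, param in model.named_parameters():
    if param.grad is not None:
        tag = f'Gradient/{name}'
        writer.add_scalar(tag, param.grad.norm().item(), 0)
        logged_tags.append(tag)

writer.close()
```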

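Flush regularly: SummaryWriter buffers events and writes them to disk every flush_secs seconds (120 by default), or sooner when its queue of max_queue pending events (10 by default) fills. Calling writer.flush() periodically during long training runs narrows that window. If your script crashes, anything still buffered is lost.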
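You can set a shorter flush interval when you create the writer. flush_secs and max_queue are SummaryWriter constructor arguments, and an explicit flush() call writes whatever is buffered immediately. In this sketch, the 30-second interval is an arbitrary choice and runs/flush_demo is a hypothetical directory.

```python
import glob
import os

from torch.utils.tensorboard import SummaryWriter

# flush_secs (default 120): how often buffered events are written to disk.
# max_queue (default 10): how many events may queue before an add_* call forces a flush.
log_dir = os.path.join('runs', 'flush_demo')
writer = SummaryWriter(log_dir, max_queue=10, flush_secs=30)

for step in range(5):
    writer.add_scalar('Loss/train', 1.0 / (step + 1), step)

writer.flush()  # write everything buffered so far without waiting for the timer
event_files = glob.glob(os.path.join(log_dir, 'events.out.tfevents.*'))
writer.close()
```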

Close writers: Always close writers when done, either explicitly or using context managers. Unclosed writers may not write all data.

Avoid logging too frequently: Logging scalars on every batch bloats event files and slows down TensorBoard’s dashboard. Log batch-level metrics every N batches and epoch-level metrics once per epoch.

TensorBoard transforms PyTorch development from printf debugging to systematic experiment tracking. The initial setup overhead pays dividends when you need to compare dozens of experiments or debug subtle training issues. Start with scalar logging, then gradually incorporate architecture visualization and advanced features as your needs grow.
