How to Use Custom Training Loops in TensorFlow

Key Insights

  • Custom training loops give you complete control over the training process, essential for research experimentation, multi-optimizer scenarios like GANs, and complex gradient manipulation that model.fit() can’t handle elegantly.
  • The core pattern revolves around tf.GradientTape for automatic differentiation—you manually execute the forward pass, compute loss, calculate gradients, and apply them through optimizers, giving you full transparency into each training step.
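  • Wrapping the training step in @tf.function compiles it into a graph, which can substantially speed up training (the gain depends heavily on model size and hardware, and small models dominated by Python overhead benefit most). It requires care with Python control flow and side effects to avoid retracing and other performance pitfalls.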

Introduction & When to Use Custom Training Loops

TensorFlow’s model.fit() is convenient and handles most standard training scenarios with minimal code. It automatically manages the training loop, metrics tracking, callbacks, and even distributed training. However, the moment you need fine-grained control over the training process, you’ll hit its limitations.

Custom training loops become necessary when you’re implementing research papers with novel training procedures, working with multiple models that need separate optimizers (like GANs or actor-critic methods), applying custom gradient transformations, or debugging training dynamics at a granular level. If you’ve ever tried forcing complex logic into a custom callback or loss function, you know how awkward it feels—that’s when you need a custom loop.

The tradeoff is straightforward: you write more code but gain complete transparency and control. You’ll see exactly what happens at each training step, making debugging easier and experimentation more flexible.

Basic Custom Training Loop Structure

A custom training loop replaces the black box of model.fit() with explicit steps: iterate through batches, perform a forward pass inside tf.GradientTape, compute loss, calculate gradients, and apply them via an optimizer.

Here’s a minimal example training a simple neural network on MNIST:

import tensorflow as tf
from tensorflow import keras

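# Load and preprocess data: flatten 28x28 images to 784 features and scale to [0, 1]
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
y_train = y_train.astype('int32')
# Preprocess the test split the same way (it is used for validation later)
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
y_test = y_test.astype('int32')

# Create dataset: shuffle, batch, and prepare the next batch while training
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(10000).batch(32).prefetch(tf.data.AUTOTUNE)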

# Define model
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10)
])

# Define optimizer and loss
optimizer = keras.optimizers.Adam(learning_rate=0.001)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Training loop
epochs = 5
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    for step, (x_batch, y_batch) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            # Forward pass
            logits = model(x_batch, training=True)
            # Compute loss
            loss_value = loss_fn(y_batch, logits)
        
        # Compute gradients
        gradients = tape.gradient(loss_value, model.trainable_variables)
        # Apply gradients
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        if step % 200 == 0:
            print(f"Step {step}: Loss = {loss_value:.4f}")

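The tf.GradientTape context manager records the operations executed inside it that involve watched tensors; trainable variables are watched automatically. After the with block closes, tape.gradient() walks that record backward to compute the derivatives of the loss with respect to the model parameters. The gradient computation and the optimizer update happen outside the tape, so they are not recorded themselves.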
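To see what the tape records in isolation, here is a minimal sketch outside any model: a trainable variable is watched automatically, while a plain tensor has to be registered with tape.watch(). The tape is created with persistent=True only because this example calls tape.gradient() twice; a training step that computes one gradient per step does not need it.

```python
import tensorflow as tf

# A tf.Variable is watched by the tape automatically
w = tf.Variable(3.0)
# A plain tensor is not, so it must be watched explicitly
x = tf.constant(2.0)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = w * w * x  # y = w^2 * x

dy_dw = tape.gradient(y, w)  # 2 * w * x = 12.0
dy_dx = tape.gradient(y, x)  # w^2 = 9.0
del tape  # release the resources held by the persistent tape

print(float(dy_dw), float(dy_dx))  # 12.0 9.0
```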

Implementing Custom Metrics and Logging

Tracking metrics properly requires resetting them between epochs and updating them per batch. TensorFlow provides stateful metrics that handle averaging automatically:

# Define metrics
train_loss_metric = keras.metrics.Mean(name='train_loss')
train_accuracy_metric = keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

epochs = 5
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    # Reset metrics at the start of each epoch
    # (reset_state() is the current name; the older reset_states() is gone in Keras 3)
    train_loss_metric.reset_state()
    train_accuracy_metric.reset_state()
    
    for step, (x_batch, y_batch) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch, training=True)
            loss_value = loss_fn(y_batch, logits)
        
        gradients = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        # Update metrics
        train_loss_metric.update_state(loss_value)
        train_accuracy_metric.update_state(y_batch, logits)
        
        if step % 200 == 0:
            print(f"Step {step}: Loss = {train_loss_metric.result():.4f}, "
                  f"Accuracy = {train_accuracy_metric.result():.4f}")
    
    # Display epoch results
    print(f"Epoch {epoch + 1} - Loss: {train_loss_metric.result():.4f}, "
          f"Accuracy: {train_accuracy_metric.result():.4f}")

The Mean metric keeps a running average of the loss values it receives, while SparseCategoricalAccuracy tracks the fraction of correct predictions. Call reset_state() at the beginning of each epoch so statistics do not carry over from the previous one.
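One subtlety is worth a sketch: Mean weights every update_state() call equally, so averaging per-batch mean losses slightly over-weights a smaller final batch (for example, MNIST's 10,000 test images in batches of 32 leave a final batch of 16). Passing the batch size as sample_weight recovers the per-example mean. The loss values below are made up purely for illustration:

```python
import tensorflow as tf

# Two batches: 3 examples with mean loss 1.0, then a partial batch of 1 with loss 5.0
batch_losses = [1.0, 5.0]
batch_sizes = [3.0, 1.0]

unweighted = tf.keras.metrics.Mean()
weighted = tf.keras.metrics.Mean()
for loss, size in zip(batch_losses, batch_sizes):
    unweighted.update_state(loss)                    # each batch counts once
    weighted.update_state(loss, sample_weight=size)  # each batch counts by its size

unweighted_value = float(unweighted.result())
weighted_value = float(weighted.result())
print(unweighted_value)  # 3.0: plain average of the two batch means
print(weighted_value)    # 2.0: (3 * 1.0 + 1 * 5.0) / 4, the per-example mean

unweighted.reset_state()
print(float(unweighted.result()))  # 0.0 after a reset
```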

Advanced Techniques: Multiple Optimizers and Custom Gradients

GANs and other adversarial architectures require separate optimizers for different model components. You also frequently need gradient clipping to stabilize training:

# Define generator and discriminator
generator = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=(100,)),
    keras.layers.Dense(784, activation='sigmoid')
])

discriminator = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    keras.layers.Dense(1)
])

# Separate optimizers
gen_optimizer = keras.optimizers.Adam(learning_rate=0.0002)
disc_optimizer = keras.optimizers.Adam(learning_rate=0.0002)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)

def train_step(real_images):
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal([batch_size, 100])
    
    # Train discriminator
    with tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)
        
        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)
        
        real_loss = loss_fn(tf.ones_like(real_output), real_output)
        fake_loss = loss_fn(tf.zeros_like(fake_output), fake_output)
        disc_loss = real_loss + fake_loss
    
    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    # Gradient clipping
    disc_gradients = [tf.clip_by_norm(g, 1.0) for g in disc_gradients]
    disc_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
    
    # Train generator
    with tf.GradientTape() as gen_tape:
        fake_images = generator(noise, training=True)
        fake_output = discriminator(fake_images, training=True)
        gen_loss = loss_fn(tf.ones_like(fake_output), fake_output)
    
    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gen_gradients = [tf.clip_by_norm(g, 1.0) for g in gen_gradients]
    gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    
    return disc_loss, gen_loss

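This pattern uses a separate GradientTape for each model, so each optimizer updates only its own model's variables. tf.clip_by_norm rescales each gradient tensor independently so its L2 norm is at most 1.0, which helps limit exploding gradients, a common issue in adversarial training.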
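Because tf.clip_by_norm clips each tensor on its own, it can change the direction of the overall update. A common alternative, not used in the code above, is tf.clip_by_global_norm, which rescales all tensors by one shared factor. A small sketch with hand-picked gradient values contrasts the two:

```python
import numpy as np
import tensorflow as tf

# Two gradient tensors, e.g. for a kernel and a bias
grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]

# Per-tensor clipping: each tensor is rescaled independently to norm <= 1.0
per_tensor = [tf.clip_by_norm(g, 1.0) for g in grads]

# Global clipping: every tensor is scaled by the same factor so the norm of
# all gradients taken together (sqrt(9 + 16 + 144) = 13) becomes 1.0
global_clipped, global_norm = tf.clip_by_global_norm(grads, 1.0)

print([g.numpy() for g in per_tensor])      # values: [0.6, 0.8] and [0.0, 1.0]
print(float(global_norm))                   # 13.0
print([g.numpy() for g in global_clipped])  # values: [3/13, 4/13] and [0.0, 12/13]
```

To try global clipping in the GAN step, one option is to replace the list comprehension with `disc_gradients, _ = tf.clip_by_global_norm(disc_gradients, 1.0)`, and likewise for the generator.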

Distributed Training with Custom Loops

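TensorFlow's distribution strategies also work with custom loops. Create the model and optimizer inside strategy.scope() so their variables are mirrored across devices, distribute the dataset, and run each training step on every replica with strategy.run():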

# Create distribution strategy (one replica per visible GPU, or the CPU if none)
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

with strategy.scope():
    # Variables (model weights, optimizer state) must be created within strategy scope
    model = keras.Sequential([
        keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        keras.layers.Dense(10)
    ])
    optimizer = keras.optimizers.Adam()
    # reduction='none' returns one loss per example; averaging happens below.
    # The string form works with both tf.keras and Keras 3.
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                          reduction='none')

# Distribute dataset: each global batch of 32 is split across the replicas
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

@tf.function
def distributed_train_step(inputs):
    def train_step(inputs):
        x_batch, y_batch = inputs
        with tf.GradientTape() as tape:
            logits = model(x_batch, training=True)
            # Compute per-example loss
            per_example_loss = loss_fn(y_batch, logits)
            # Sum and divide by the global batch size
            # (per-replica batch size x number of replicas)
            loss_value = tf.nn.compute_average_loss(per_example_loss)

        gradients = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss_value

    # Run training step on all replicas, then sum the per-replica partial losses
    per_replica_losses = strategy.run(train_step, args=(inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for epoch in range(3):
    for inputs in train_dist_dataset:
        loss = distributed_train_step(inputs)
    print(f"Epoch {epoch + 1}: last batch loss = {float(loss):.4f}")

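The key difference is computing a per-example loss (no reduction) and dividing its sum by the global batch size rather than the per-replica batch size; summing the per-replica results with strategy.reduce then gives the mean loss over the whole global batch. Under MirroredStrategy, optimizer.apply_gradients aggregates gradients across replicas, so every device applies the same update.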
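The scaling matters because the cross-replica aggregation sums gradients, and strategy.reduce(SUM, ...) sums the returned losses. If each replica divides its loss sum by the global batch size, those sums add up to the mean over the global batch. A NumPy sketch with made-up per-example losses and two simulated replicas shows the arithmetic, along with what goes wrong if each replica averages over only its local batch:

```python
import numpy as np

# Per-example losses for a global batch of 8, split across 2 simulated replicas
per_example = np.array([0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6])
replicas = np.split(per_example, 2)
global_batch_size = per_example.size

# Correct: each replica divides its sum by the GLOBAL batch size, then SUM-reduce
scaled = [r.sum() / global_batch_size for r in replicas]
print(sum(scaled), per_example.mean())  # both approximately 0.9

# Wrong: each replica averages over its LOCAL batch, then SUM-reduce
local_means = [r.mean() for r in replicas]
print(sum(local_means))  # approximately 1.8, twice the true mean with 2 replicas
```

The same factor applies to the gradients, so the "wrong" version effectively multiplies the learning rate by the number of replicas.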

Validation, Checkpointing, and Early Stopping

Production training loops need validation, checkpointing, and early stopping. Here’s a complete implementation:

# Setup: this section assumes a single-device loop, so rebuild the model,
# optimizer, and a mean-reducing loss instead of reusing the distributed ones
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10)
])
optimizer = keras.optimizers.Adam(learning_rate=0.001)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
checkpoint_manager = tf.train.CheckpointManager(checkpoint, './checkpoints', max_to_keep=3)

val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

best_val_loss = float('inf')
patience = 3
patience_counter = 0

@tf.function
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, logits)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss_value

@tf.function
def val_step(x_batch, y_batch):
    # training=False puts layers such as dropout or batch norm in inference mode
    logits = model(x_batch, training=False)
    return loss_fn(y_batch, logits)

val_loss_metric = keras.metrics.Mean(name='val_loss')

for epoch in range(50):
    # Training
    train_loss_metric.reset_state()
    for x_batch, y_batch in train_dataset:
        loss = train_step(x_batch, y_batch)
        train_loss_metric.update_state(loss)
    
    # Validation
    val_loss_metric.reset_state()
    for x_batch, y_batch in val_dataset:
        val_loss_metric.update_state(val_step(x_batch, y_batch))
    
    # Convert to a Python float for printing and comparison
    val_loss = float(val_loss_metric.result())
    print(f"Epoch {epoch + 1}: Train Loss = {train_loss_metric.result():.4f}, "
          f"Val Loss = {val_loss:.4f}")
    
    # Checkpointing: save only when validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint_manager.save()
        patience_counter = 0
    else:
        patience_counter += 1
    
    # Early stopping
    if patience_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch + 1}")
        break

# Restore the best weights (the manager only saved on improvement)
checkpoint.restore(checkpoint_manager.latest_checkpoint)
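The patience logic can also be pulled out of the loop so it is easy to unit-test. The EarlyStopper class below is a hypothetical helper, not a TensorFlow API; its min_delta parameter is an added assumption, similar in spirit to the min_delta argument of Keras's EarlyStopping callback. The validation losses are invented for illustration:

```python
class EarlyStopper:
    """Hypothetical helper (not a TensorFlow/Keras API) that tracks the best
    validation loss and reports when patience has run out."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best = float('inf')
        self.counter = 0

    def update(self, val_loss):
        """Record one epoch's validation loss; return True if it improved."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
            return True
        self.counter += 1
        return False

    @property
    def should_stop(self):
        return self.counter >= self.patience


stopper = EarlyStopper(patience=3)
val_losses = [1.0, 0.8, 0.85, 0.9, 0.95, 0.7]
for epoch, val_loss in enumerate(val_losses, start=1):
    if stopper.update(val_loss):
        print(f"Epoch {epoch}: improved to {val_loss}")  # save a checkpoint here
    if stopper.should_stop:
        print(f"Early stopping triggered at epoch {epoch}")  # epoch 5
        break
```

Note that the final 0.7 is never seen: with patience=3, three epochs without improvement after 0.8 end training first.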

Best Practices and Performance Optimization

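The @tf.function decorator traces a Python function into a TensorFlow graph. This removes per-operation Python overhead and lets TensorFlow optimize the computation, which often speeds up training noticeably, especially for small models. The benchmark below compares the two modes; exact numbers depend on your hardware and model: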

import time

# Eager execution (slow)
def eager_train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, logits)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss_value

# Graph execution (fast)
@tf.function
def graph_train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, logits)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss_value

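# Benchmark on a single repeated batch
sample_batch = next(iter(train_dataset))

# Warm up both versions first: the first call to a tf.function includes a
# one-time tracing cost that should not count toward steady-state timing
eager_train_step(*sample_batch)
graph_train_step(*sample_batch)

start = time.time()
for _ in range(100):
    loss = eager_train_step(*sample_batch)
loss.numpy()  # wait for pending device work before stopping the clock
eager_time = time.time() - start

start = time.time()
for _ in range(100):
    loss = graph_train_step(*sample_batch)
loss.numpy()
graph_time = time.time() - start

print(f"Eager execution: {eager_time:.3f}s")
print(f"Graph execution: {graph_time:.3f}s")
print(f"Speedup: {eager_time / graph_time:.1f}x")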

Key optimization tips:

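  • Wrap training steps in @tf.function for production runs, once they work correctly in eager mode (debugging is easier eagerly)
  • Avoid Python side effects (print(), appending to lists) inside @tf.function: they run only while the function is traced, not on every call; use tf.print for per-step logging
  • Use the tf.data API for input pipelines, with prefetching and parallel map calls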
  • Profile with TensorBoard to identify bottlenecks
  • When in doubt, stick with model.fit() for standard supervised learning—custom loops add complexity that’s only justified when you need the control
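The side-effect rule is easiest to see with a counter. In this sketch (scaled_sum and trace_log are illustrative names), a Python list append inside a tf.function runs only when TensorFlow traces the function, and passing plain Python numbers instead of tensors triggers a fresh trace for every new value. tf.print, by contrast, is a graph operation and would run on every call:

```python
import tensorflow as tf

trace_log = []  # plain Python list, outside the graph

@tf.function
def scaled_sum(x, scale):
    trace_log.append("traced")  # Python side effect: runs only while tracing
    return tf.reduce_sum(x) * scale

x = tf.constant([1.0, 2.0, 3.0])

for _ in range(3):
    result = scaled_sum(x, tf.constant(2.0))  # same input signature: traced once
print(len(trace_log), float(result))  # 1 12.0

scaled_sum(x, 2.0)  # a Python float is baked in as a constant: new trace
scaled_sum(x, 3.0)  # a different Python value: yet another trace
print(len(trace_log))  # 3
```

In a training loop, this is why values such as a learning rate or step counter are usually passed as tensors or tf.Variable objects rather than Python numbers.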

Custom training loops are powerful but come with responsibility. Use them when you genuinely need the flexibility, and always benchmark against model.fit() to ensure the added complexity is worthwhile.
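For that comparison, a model.fit() baseline with the same architecture takes only a few lines. The sketch below uses random synthetic data (an assumption so it runs on its own, which makes the accuracy meaningless) and a prefetching tf.data pipeline; to compare against the custom loops above, pass the real train_dataset instead and time both approaches on the same hardware:

```python
import time

import numpy as np
import tensorflow as tf

# Synthetic stand-in for flattened MNIST: 256 examples, 784 features, 10 classes
rng = np.random.default_rng(0)
x = rng.random((256, 784), dtype=np.float32)
y = rng.integers(0, 10, size=256).astype("int32")

dataset = (tf.data.Dataset.from_tensor_slices((x, y))
           .shuffle(256)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

start = time.time()
history = model.fit(dataset, epochs=2, verbose=0)
print(f"model.fit: {time.time() - start:.3f}s for 2 epochs")
```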
