How to Use Callbacks in TensorFlow


Key Insights

  • Callbacks provide hooks into the training loop to monitor metrics, save checkpoints, adjust hyperparameters, and implement custom logic without modifying core training code
  • TensorFlow offers powerful built-in callbacks like ModelCheckpoint and EarlyStopping, but custom callbacks unlock advanced patterns like gradient monitoring, dynamic learning rate schedules, and integration with external services
  • Callback methods fire at specific training stages (train, epoch, and batch begin/end); knowing when and how often each one runs is critical for avoiding performance bottlenecks and implementing complex training workflows

Introduction to TensorFlow Callbacks

Callbacks are functions that execute at specific points during model training, giving you programmatic control over the training process. Instead of writing monolithic training loops with hardcoded logic, callbacks let you inject custom behavior in a modular, reusable way.

Think of callbacks as event listeners for your training loop. When certain events occur—an epoch completes, a batch finishes processing, or training begins—TensorFlow triggers the appropriate callback methods. This architecture separates concerns: your model definition stays clean, while training logic lives in composable callback objects.
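A small sketch with no framework code makes the event-listener model concrete. ToyCallback, EventRecorder, and toy_fit are hypothetical names used only for illustration. Keras's real dispatcher supports many more hooks, but the shape is similar: the loop owns the control flow and calls each callback's methods, in list order, whenever an event occurs.

```python
class ToyCallback:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_train_begin(self, logs=None): pass
    def on_epoch_begin(self, epoch, logs=None): pass
    def on_batch_end(self, batch, logs=None): pass
    def on_epoch_end(self, epoch, logs=None): pass
    def on_train_end(self, logs=None): pass


class EventRecorder(ToyCallback):
    """Records every event it receives so the firing order can be inspected."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end:{batch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


def toy_fit(epochs, batches_per_epoch, callbacks):
    """A fake training loop: no model, only the event-dispatch skeleton."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            # A real loop would run the forward and backward pass here.
            batch_logs = {"loss": 1.0 / (epoch * batches_per_epoch + batch + 1)}
            for cb in callbacks:
                cb.on_batch_end(batch, batch_logs)
        epoch_logs = {"loss": batch_logs["loss"]}
        for cb in callbacks:
            cb.on_epoch_end(epoch, epoch_logs)
    for cb in callbacks:
        cb.on_train_end()


recorder = EventRecorder()
toy_fit(epochs=2, batches_per_epoch=2, callbacks=[recorder])
print(recorder.events)
```

The recorded order is train_begin, then for each epoch an epoch_begin, one batch_end per batch, and an epoch_end, and finally train_end.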

Here’s where callbacks fit into the training workflow:

import tensorflow as tf
from tensorflow import keras

# Define your model
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Callbacks inject custom logic into training
callbacks = [
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=3),
    keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)
]

# Training loop with callbacks
history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=50,
    callbacks=callbacks  # Callbacks execute automatically
)

Built-in Callbacks Overview

TensorFlow provides battle-tested callbacks for common training scenarios. Here are the ones you’ll use most frequently:

ModelCheckpoint saves your model at intervals, typically preserving only the best version based on a monitored metric. This prevents losing progress if training crashes and ensures you keep the optimal model state.

EarlyStopping halts training when a metric stops improving, preventing overfitting and saving computational resources.
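At its core, EarlyStopping is a patience counter. The function below is a simplified model of that logic for a metric that should decrease; early_stopping_epoch is a hypothetical helper, not a Keras API, and the real callback adds options such as baseline and restore_best_weights, so treat this as a sketch of the idea rather than its exact behavior.

```python
def early_stopping_epoch(values, patience, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a metric that should decrease.

    stop_epoch is None if training would run through every value.
    """
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, current in enumerate(values):
        if current < best - min_delta:
            # Improvement: remember it and reset the patience counter.
            best, best_epoch, wait = current, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


val_losses = [0.90, 0.70, 0.72, 0.71, 0.75, 0.60]
print(early_stopping_epoch(val_losses, patience=3))
print(early_stopping_epoch(val_losses, patience=5))
```

With patience=3 training stops at epoch 4 and never sees the better loss at epoch 5; with patience=5 it keeps going and finds it. That trade-off is why patience is worth tuning rather than leaving at a small default.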

TensorBoard logs metrics, histograms, and graphs for visualization in TensorBoard’s web interface.

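ReduceLROnPlateau multiplies the learning rate by a configurable factor when a monitored metric stops improving for a set number of epochs (its patience). Smaller steps often let the optimizer keep making progress after larger ones have stalled.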
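The same patience idea drives ReduceLROnPlateau, except that running out of patience shrinks the learning rate instead of stopping. The function below is a simplified model for a minimized metric; plateau_schedule is a hypothetical helper, and it omits the real callback's cooldown option.

```python
def plateau_schedule(values, initial_lr, factor=0.5, patience=3,
                     min_lr=1e-7, min_delta=1e-4):
    """Return the learning rate in effect after each epoch for a minimized metric."""
    lr = initial_lr
    best = float("inf")
    wait = 0
    lrs = []
    for current in values:
        if current < best - min_delta:
            best, wait = current, 0
        else:
            wait += 1
            # Out of patience: shrink the LR, never below min_lr, and start counting again.
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)
                wait = 0
        lrs.append(lr)
    return lrs


print(plateau_schedule([1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9], initial_lr=0.01))
```

The loss flattens at epoch 1, so the rate halves after three stalled epochs (epoch 4) and again three epochs later (epoch 7). min_lr acts as a floor once repeated reductions reach it.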

CSVLogger streams epoch results to a CSV file for later analysis.
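To see what "later analysis" looks like in practice, the sketch below trains a tiny model on synthetic data with CSVLogger attached and reads the file back with Python's csv module. The data, model, and file name are made up for illustration, and it assumes TensorFlow 2.x is installed.

```python
import csv
import os
import tempfile

import numpy as np
from tensorflow import keras

# Tiny synthetic regression problem so the example runs in seconds.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

log_path = os.path.join(tempfile.mkdtemp(), "training_log.csv")
model.fit(
    x, y,
    validation_split=0.25,
    epochs=3,
    batch_size=16,
    verbose=0,
    callbacks=[keras.callbacks.CSVLogger(log_path)],  # append=True would extend an existing file
)

# One row per epoch; columns are "epoch" plus the metric names found in the logs.
with open(log_path, newline="") as f:
    rows = list(csv.DictReader(f))

for row in rows:
    print(row["epoch"], row["loss"], row["val_loss"])
```

Because the file is plain CSV, it loads directly into a spreadsheet or a dataframe library for plotting.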

Here’s a realistic training setup combining multiple callbacks:

import tensorflow as tf
from tensorflow import keras
import os

# Create checkpoint directory
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Configure callbacks
callbacks = [
    # Save best model based on validation accuracy
    keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, 'model_{epoch:02d}_{val_accuracy:.3f}.h5'),
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    
    # Stop if validation loss doesn't improve for 5 epochs
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    
    # Reduce learning rate when validation loss plateaus
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-7,
        verbose=1
    ),
    
    # Log metrics to TensorBoard
    keras.callbacks.TensorBoard(
        log_dir='./logs',
        histogram_freq=1,
        write_graph=True
    )
]

model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=100,
    callbacks=callbacks
)
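The {epoch:02d} and {val_accuracy:.3f} fields in the checkpoint path above are ordinary Python format fields. To my understanding, Keras fills them with str.format using the 1-based epoch number and that epoch's logs, so you can preview file names before a long run. The log values below are invented for illustration.

```python
template = "model_{epoch:02d}_{val_accuracy:.3f}.h5"

# Hypothetical logs for the 0-based epoch index 3 (the 4th epoch).
epoch = 3
logs = {"loss": 0.2871, "accuracy": 0.9132, "val_loss": 0.3015, "val_accuracy": 0.9046}

filename = template.format(epoch=epoch + 1, **logs)
print(filename)  # model_04_0.905.h5
```

A field missing from the logs (for example val_accuracy when fit() has no validation data) makes the formatting fail with a KeyError, so keep the template in sync with the metrics you actually compute.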

Creating Custom Callbacks

Built-in callbacks cover common scenarios; for anything else, write your own. Subclass tf.keras.callbacks.Callback and override the methods corresponding to the training events you need:

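  • on_train_begin/end: Called at the start/end of fit()
  • on_epoch_begin/end: Called at the start/end of each training epoch
  • on_train_batch_begin/end: Called before/after each training batch (on_batch_begin/end are backward-compatible aliases that fire for training batches)
  • on_test_begin/end (and on_test_batch_begin/end): Called during evaluate() and during the validation pass inside each fit() epoch
  • on_predict_begin/end (and on_predict_batch_begin/end): Called during predict()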
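One way to see these hooks in action is a callback that only records which methods Keras calls. HookRecorder is a hypothetical name, the data is random, and the sketch assumes TensorFlow 2.x; it runs fit() with validation data, then evaluate() and predict(), using the same callback instance.

```python
import numpy as np
from tensorflow import keras


class HookRecorder(keras.callbacks.Callback):
    """Appends the name of each hook it receives to self.calls."""

    def __init__(self):
        super().__init__()
        self.calls = []

    def _record(self, name):
        self.calls.append(name)

    def on_train_begin(self, logs=None):
        self._record("on_train_begin")

    def on_train_end(self, logs=None):
        self._record("on_train_end")

    def on_epoch_begin(self, epoch, logs=None):
        self._record("on_epoch_begin")

    def on_epoch_end(self, epoch, logs=None):
        self._record("on_epoch_end")

    def on_train_batch_end(self, batch, logs=None):
        self._record("on_train_batch_end")

    def on_test_begin(self, logs=None):
        self._record("on_test_begin")

    def on_test_end(self, logs=None):
        self._record("on_test_end")

    def on_predict_begin(self, logs=None):
        self._record("on_predict_begin")

    def on_predict_end(self, logs=None):
        self._record("on_predict_end")


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 4)).astype("float32")
y = rng.integers(0, 2, size=(32,))

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(2, activation="softmax")])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

recorder = HookRecorder()
# 32 samples with batch_size=16 gives 2 training batches per epoch.
model.fit(x, y, validation_data=(x, y), epochs=2, batch_size=16,
          verbose=0, callbacks=[recorder])
fit_calls = list(recorder.calls)

recorder.calls.clear()
model.evaluate(x, y, verbose=0, callbacks=[recorder])
eval_calls = list(recorder.calls)

recorder.calls.clear()
model.predict(x, verbose=0, callbacks=[recorder])
predict_calls = list(recorder.calls)

print(fit_calls)
print(eval_calls)
print(predict_calls)
```

Inside fit(), the validation pass (on_test_begin/on_test_end) runs before on_epoch_end, which is why val_ metrics are already present in the logs your on_epoch_end receives.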

Here’s a custom callback that monitors learning rate and sends notifications:

import tensorflow as tf
from tensorflow import keras
import requests

class LearningRateLogger(keras.callbacks.Callback):
    """Log the learning rate and send a one-time Slack notification at a milestone."""
    
    def __init__(self, slack_webhook_url=None):
        super().__init__()
        self.slack_webhook_url = slack_webhook_url
        self.milestone_notified = False  # send the 95% message only once
        
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Access current learning rate (optimizer.lr is deprecated in favor of learning_rate)
        lr = float(self.model.optimizer.learning_rate.numpy())
        print(f"\nEpoch {epoch + 1}: Learning rate is {lr:.6f}")
        
        # Check the accuracy milestone (val_accuracy exists only when fit() gets validation data)
        val_accuracy = logs.get('val_accuracy', 0)
        
        if val_accuracy > 0.95 and self.slack_webhook_url and not self.milestone_notified:
            self.milestone_notified = True
            self._send_slack_message(
                f"🎉 Model reached 95% validation accuracy at epoch {epoch + 1}!"
            )
    
    def _send_slack_message(self, message):
        """Send notification to Slack webhook."""
        try:
            response = requests.post(
                self.slack_webhook_url,
                json={'text': message},  # requests serializes the payload and sets Content-Type
                timeout=10,  # don't let a slow webhook stall training
            )
            response.raise_for_status()
        except requests.RequestException as e:
            print(f"Failed to send Slack notification: {e}")

# Usage (validation data is required for val_accuracy to appear in logs)
callback = LearningRateLogger(slack_webhook_url='https://hooks.slack.com/services/YOUR/WEBHOOK/URL')
model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=50, callbacks=[callback])

Accessing Model Metrics and Parameters

Callbacks receive a logs dictionary containing current metrics and can access the model directly via self.model. This enables sophisticated monitoring and dynamic training adjustments.

Here’s a callback that tracks gradient norms and implements custom early stopping:

import tensorflow as tf
from tensorflow import keras
import numpy as np

class GradientMonitor(keras.callbacks.Callback):
    """Monitor gradient norms and implement custom early stopping."""
    
    def __init__(self, sample_batch, loss_fn, patience=5, grad_norm_threshold=100.0):
        super().__init__()
        # sample_batch: one (inputs, targets) batch used to estimate the gradient norm
        # loss_fn: a loss callable matching the loss passed to compile()
        self.sample_inputs, self.sample_targets = sample_batch
        self.loss_fn = loss_fn
        self.patience = patience
        self.grad_norm_threshold = grad_norm_threshold
        self.wait = 0
        self.best_loss = np.inf
        self.gradient_norms = []
        
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Estimate the gradient norm on the fixed sample batch
        # (a single-batch estimate, not the gradient over the whole epoch)
        with tf.GradientTape() as tape:
            predictions = self.model(self.sample_inputs, training=True)
            loss = self.loss_fn(self.sample_targets, predictions)
        
        grads = tape.gradient(loss, self.model.trainable_weights)
        grad_norm = float(tf.linalg.global_norm(grads))
        self.gradient_norms.append(grad_norm)
        
        print(f"\nEpoch {epoch + 1}: Gradient norm = {grad_norm:.4f}")
        
        # Check for exploding gradients
        if grad_norm > self.grad_norm_threshold:
            print(f"⚠️  Warning: Gradient norm exceeds threshold ({self.grad_norm_threshold})!")
            self.model.stop_training = True
            return
        
        # Custom early stopping based on multiple metrics
        current_loss = logs.get('val_loss', np.inf)
        current_acc = logs.get('val_accuracy', 0)
        
        # Stop if val_loss hasn't improved by at least 1% for `patience` epochs
        # AND val_accuracy is still below 0.9
        if current_loss < self.best_loss * 0.99:  # 1% improvement threshold
            self.best_loss = current_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience and current_acc < 0.9:
                print(f"\nEarly stopping: No improvement for {self.patience} epochs")
                self.model.stop_training = True
    
    def on_train_end(self, logs=None):
        if not self.gradient_norms:
            return  # nothing recorded (training ended before the first epoch finished)
        print(f"\nGradient norm statistics:")
        print(f"  Mean: {np.mean(self.gradient_norms):.4f}")
        print(f"  Max: {np.max(self.gradient_norms):.4f}")
        print(f"  Min: {np.min(self.gradient_norms):.4f}")

# Usage: pass one (inputs, targets) batch and the same loss used in compile()
sample_batch = next(iter(train_dataset))
monitor = GradientMonitor(sample_batch, keras.losses.SparseCategoricalCrossentropy())
model.fit(train_dataset, validation_data=val_dataset, epochs=50, callbacks=[monitor])
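Callbacks can also write to the logs dictionary, not just read from it. The hypothetical WeightNormLogger below uses self.model to compute the global L2 norm of the trainable weights and stores it under a new key; because the same logs reach later callbacks, the value also ends up in the History object returned by fit(). This is a sketch on synthetic data, assuming TensorFlow 2.x.

```python
import numpy as np
from tensorflow import keras


class WeightNormLogger(keras.callbacks.Callback):
    """Hypothetical example: add the global L2 norm of the trainable weights to the epoch logs."""

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return
        # self.model is attached by fit(); trainable_weights are the variables the optimizer updates.
        total = sum(float(np.sum(np.square(w.numpy()))) for w in self.model.trainable_weights)
        # Writing into logs makes the value visible to later callbacks and to History.
        logs["weight_norm"] = float(np.sqrt(total))


# Tiny synthetic regression data so the example runs quickly.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

history = model.fit(x, y, epochs=3, batch_size=16, verbose=0, callbacks=[WeightNormLogger()])
print(history.history["weight_norm"])
```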

Advanced Callback Patterns

Callbacks can maintain state across epochs and implement complex training schedules. Here’s a learning rate scheduler with warmup and cyclic patterns:

import tensorflow as tf
from tensorflow import keras
import numpy as np

class WarmupCyclicLR(keras.callbacks.Callback):
    """Learning rate scheduler with warmup and cyclic decay."""
    
    def __init__(self, base_lr=1e-4, max_lr=1e-2, warmup_epochs=5, 
                 cycle_length=10, decay_factor=0.9):
        super().__init__()
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_epochs = warmup_epochs
        self.cycle_length = cycle_length
        self.decay_factor = decay_factor
        self.history = []
        
    def on_epoch_begin(self, epoch, logs=None):
        # Warmup phase: linear increase
        if epoch < self.warmup_epochs:
            lr = self.base_lr + (self.max_lr - self.base_lr) * (epoch / self.warmup_epochs)
        else:
            # Cyclic phase with decay
            cycle_num = (epoch - self.warmup_epochs) // self.cycle_length
            cycle_pos = (epoch - self.warmup_epochs) % self.cycle_length
            
            # Cosine annealing within cycle
            cycle_progress = cycle_pos / self.cycle_length
            current_max_lr = self.max_lr * (self.decay_factor ** cycle_num)
            lr = self.base_lr + (current_max_lr - self.base_lr) * \
                 (1 + np.cos(np.pi * cycle_progress)) / 2
        
        self.model.optimizer.learning_rate = float(lr)  # optimizer.lr is deprecated; set learning_rate
        self.history.append(lr)
        print(f"\nEpoch {epoch + 1}: Learning rate = {lr:.6f}")

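# Usage with multiple callbacks. WarmupCyclicLR acts in on_epoch_begin, so it
# runs before ModelCheckpoint and EarlyStopping (which act in on_epoch_end)
# wherever it sits in the list. Both of those monitor val_loss by default,
# so pass validation data to fit().
callbacks = [
    WarmupCyclicLR(base_lr=1e-5, max_lr=1e-3),
    keras.callbacks.ModelCheckpoint('model.h5', save_best_only=True),
    keras.callbacks.EarlyStopping(patience=10)
]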
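Because the schedule in on_epoch_begin is pure arithmetic, you can lift it into a plain function and preview the learning rate for every epoch before launching a run. warmup_cyclic_lr is a hypothetical standalone copy of the WarmupCyclicLR formula (using math instead of NumPy), with the class's default arguments.

```python
import math


def warmup_cyclic_lr(epoch, base_lr=1e-4, max_lr=1e-2, warmup_epochs=5,
                     cycle_length=10, decay_factor=0.9):
    """Same arithmetic as WarmupCyclicLR.on_epoch_begin, as a standalone function."""
    if epoch < warmup_epochs:
        # Linear warmup from base_lr toward max_lr.
        return base_lr + (max_lr - base_lr) * (epoch / warmup_epochs)
    cycle_num = (epoch - warmup_epochs) // cycle_length
    cycle_pos = (epoch - warmup_epochs) % cycle_length
    cycle_progress = cycle_pos / cycle_length
    # Each cycle's peak shrinks by decay_factor; cosine annealing within a cycle.
    current_max_lr = max_lr * (decay_factor ** cycle_num)
    return base_lr + (current_max_lr - base_lr) * (1 + math.cos(math.pi * cycle_progress)) / 2


for epoch in (0, 4, 5, 10, 14, 15):
    print(f"epoch {epoch:2d}: lr = {warmup_cyclic_lr(epoch):.6f}")
```

The preview exposes a design choice worth knowing about: the peak is reached at the first post-warmup epoch (epoch 5), and each new cycle restarts abruptly, dropping to about 0.00034 at epoch 14 and jumping back to 0.009 at epoch 15. If those jumps destabilize training, a smaller decay_factor or a longer cycle_length softens them.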

Best Practices and Common Pitfalls

Avoid expensive operations in on_batch_end. This method executes after every batch—potentially thousands of times per epoch. Heavy computations here will cripple training speed.

import numpy as np
from tensorflow import keras

# save_to_database is a placeholder for your own storage code (not defined here).

# ❌ BAD: Expensive operation every batch
class SlowCallback(keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        # This runs after every training batch, potentially thousands of times per epoch!
        heavy_computation = np.linalg.svd(self.model.layers[0].get_weights()[0])
        self.save_to_database(heavy_computation)

# ✅ GOOD: Expensive operations only per epoch
class FastCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Runs once per epoch
        if epoch % 5 == 0:  # and only every 5th epoch
            heavy_computation = np.linalg.svd(self.model.layers[0].get_weights()[0])
            self.save_to_database(heavy_computation)
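When you genuinely need batch-level data, a common compromise is to keep the per-batch work to a cheap counter check and do the real work every N batches. ThrottledBatchLogger and its every_n_batches parameter are hypothetical names; the sketch overrides on_train_batch_end (the training-batch hook that on_batch_end aliases) and assumes TensorFlow 2.x.

```python
from tensorflow import keras


class ThrottledBatchLogger(keras.callbacks.Callback):
    """Records the batch loss every `every_n_batches` training batches instead of every batch."""

    def __init__(self, every_n_batches=100):
        super().__init__()
        self.every_n_batches = every_n_batches
        self.samples = []  # (global_step, loss) pairs
        self._step = 0  # counts batches across epochs; `batch` restarts at 0 each epoch

    def on_train_batch_end(self, batch, logs=None):
        self._step += 1
        # Cheap modulo check on every batch; the real work runs only occasionally.
        if self._step % self.every_n_batches != 0:
            return
        if logs and "loss" in logs:
            self.samples.append((self._step, float(logs["loss"])))


# The hook can be exercised directly with handcrafted logs, no training needed.
cb = ThrottledBatchLogger(every_n_batches=3)
for i in range(10):
    cb.on_train_batch_end(i, {"loss": 1.0 / (i + 1)})
print(cb.samples)
```

In a real run you would pass it like any other callback, for example model.fit(..., callbacks=[ThrottledBatchLogger(every_n_batches=50)]).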

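Use callback execution order strategically. Within each hook, callbacks run in list order and typically receive the same logs dictionary, so order matters when one callback depends on another's output. For example, a callback that adds a custom metric to logs in on_epoch_end must come before an EarlyStopping or ModelCheckpoint that monitors that metric. Order across different hooks is fixed by the training loop: a scheduler that sets the learning rate in on_epoch_begin always runs before any callback's on_epoch_end, wherever it sits in the list.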
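A small experiment shows how list order affects callbacks that share the logs dictionary. GapWriter (which adds a derived gap = val_loss - loss metric) and GapReader (which records whether it can see that metric) are hypothetical callbacks; the same pair is run in both orders on synthetic data, assuming TensorFlow 2.x.

```python
import numpy as np
from tensorflow import keras


class GapWriter(keras.callbacks.Callback):
    """Writes a derived metric, val_loss - loss, into the shared epoch logs."""

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None and "val_loss" in logs:
            logs["gap"] = logs["val_loss"] - logs["loss"]


class GapReader(keras.callbacks.Callback):
    """Records whether the derived metric was visible when this callback ran."""

    def __init__(self):
        super().__init__()
        self.saw_gap = []

    def on_epoch_end(self, epoch, logs=None):
        self.saw_gap.append(logs is not None and "gap" in logs)


def run(callbacks):
    """Train a tiny model on synthetic data with the given callback list."""
    rng = np.random.default_rng(0)
    x = rng.normal(size=(32, 4)).astype("float32")
    y = x.sum(axis=1, keepdims=True)
    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")
    model.fit(x, y, validation_split=0.25, epochs=2, batch_size=8,
              verbose=0, callbacks=callbacks)


writer_first = GapReader()
run([GapWriter(), writer_first])

reader_first = GapReader()
run([reader_first, GapWriter()])

print(writer_first.saw_gap, reader_first.saw_gap)
```

With the writer first, the reader sees gap in every epoch; with the reader first, it never does, because the key is added only after the reader has already run. An EarlyStopping(monitor="gap") placed before GapWriter would fail in the same way.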

Cache expensive lookups. If you need model information repeatedly, compute it once in on_train_begin rather than recalculating in every epoch:

class OptimizedCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Cache layer references once
        self.monitored_layers = [
            layer for layer in self.model.layers 
            if isinstance(layer, keras.layers.Dense)
        ]
    
    def on_epoch_end(self, epoch, logs=None):
        # Use cached references
        for layer in self.monitored_layers:
            weights = layer.get_weights()
            # Process weights...

Test callbacks independently. Create minimal models and datasets to verify callback behavior before deploying to expensive training runs.
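Two cheap test styles cover most callbacks: call the hook directly with handcrafted logs, then run a very short fit() as a smoke test. StopAtAccuracy is a hypothetical callback under test, and the model and data are synthetic; the sketch assumes TensorFlow 2.x, where set_model() attaches a model without training.

```python
import numpy as np
from tensorflow import keras


class StopAtAccuracy(keras.callbacks.Callback):
    """Hypothetical callback under test: stop once val_accuracy reaches a target."""

    def __init__(self, target=0.95):
        super().__init__()
        self.target = target
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if logs.get("val_accuracy", 0.0) >= self.target:
            self.stopped_epoch = epoch
            self.model.stop_training = True


tiny_model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(2, activation="softmax")])
tiny_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# 1. Unit-style check: feed handcrafted logs straight into the hook, no training at all.
cb = StopAtAccuracy(target=0.95)
cb.set_model(tiny_model)
tiny_model.stop_training = False
cb.on_epoch_end(0, {"val_accuracy": 0.80})
below_target = tiny_model.stop_training
cb.on_epoch_end(1, {"val_accuracy": 0.97})
above_target = tiny_model.stop_training

# 2. Smoke test: a real (tiny) fit confirms the callback wires up without errors.
rng = np.random.default_rng(0)
x = rng.normal(size=(32, 4)).astype("float32")
y = rng.integers(0, 2, size=(32,))
smoke_cb = StopAtAccuracy(target=2.0)  # unreachable target, so every epoch must run
history = tiny_model.fit(x, y, validation_split=0.25, epochs=2, verbose=0, callbacks=[smoke_cb])

print(below_target, above_target, cb.stopped_epoch, len(history.epoch))
```

The direct-call check runs in milliseconds and pins down the decision logic exactly; the smoke test catches wiring mistakes such as misspelled hook names or metric keys that never appear in the logs.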

Callbacks transform TensorFlow from a static training framework into a dynamic, controllable system. Master them, and you’ll build more robust, observable, and efficient training pipelines.
