# How to Use Callbacks in TensorFlow
## Key Insights
- Callbacks provide hooks into the training loop to monitor metrics, save checkpoints, adjust hyperparameters, and implement custom logic without modifying core training code
- TensorFlow offers powerful built-in callbacks like ModelCheckpoint and EarlyStopping, but custom callbacks unlock advanced patterns like gradient monitoring, dynamic learning rate schedules, and integration with external services
- Callback methods execute at specific training stages (epoch/batch begin/end), and understanding their execution order is critical for avoiding performance bottlenecks and implementing complex training workflows
## Introduction to TensorFlow Callbacks
Callbacks are functions that execute at specific points during model training, giving you programmatic control over the training process. Instead of writing monolithic training loops with hardcoded logic, callbacks let you inject custom behavior in a modular, reusable way.
Think of callbacks as event listeners for your training loop. When certain events occur—an epoch completes, a batch finishes processing, or training begins—TensorFlow triggers the appropriate callback methods. This architecture separates concerns: your model definition stays clean, while training logic lives in composable callback objects.
Here’s where callbacks fit into the training workflow:
```python
import tensorflow as tf
from tensorflow import keras

# Define your model
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Callbacks inject custom logic into training
callbacks = [
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=3),
    keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)
]

# Training loop with callbacks
history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=50,
    callbacks=callbacks  # Callbacks execute automatically
)
```
## Built-in Callbacks Overview
TensorFlow provides battle-tested callbacks for common training scenarios. Here are the ones you’ll use most frequently:
- **ModelCheckpoint** saves your model at intervals, typically preserving only the best version based on a monitored metric. This prevents losing progress if training crashes and ensures you keep the optimal model state.
- **EarlyStopping** halts training when a metric stops improving, preventing overfitting and saving computational resources.
- **TensorBoard** logs metrics, histograms, and graphs for visualization in TensorBoard’s web interface.
- **ReduceLROnPlateau** decreases the learning rate when a metric plateaus, helping models escape local minima.
- **CSVLogger** streams epoch results to a CSV file for later analysis.
Here’s a realistic training setup combining multiple callbacks:
```python
import tensorflow as tf
from tensorflow import keras
import os

# Create checkpoint directory
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Configure callbacks
callbacks = [
    # Save best model based on validation accuracy
    keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, 'model_{epoch:02d}_{val_accuracy:.3f}.h5'),
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    # Stop if validation loss doesn't improve for 5 epochs
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    # Reduce learning rate when validation loss plateaus
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-7,
        verbose=1
    ),
    # Log metrics to TensorBoard
    keras.callbacks.TensorBoard(
        log_dir='./logs',
        histogram_freq=1,
        write_graph=True
    )
]

model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=100,
    callbacks=callbacks
)
```
## Creating Custom Callbacks
Built-in callbacks cover common scenarios, but custom callbacks unlock TensorFlow’s full potential. Subclass tf.keras.callbacks.Callback and override methods corresponding to training events:
- on_train_begin/end: Called at training start/finish
- on_epoch_begin/end: Called at epoch start/finish
- on_batch_begin/end: Called before/after processing each batch
- on_test_begin/end: Called during evaluation
- on_predict_begin/end: Called during prediction
Here’s a custom callback that monitors learning rate and sends notifications:
```python
import tensorflow as tf
from tensorflow import keras
import requests
import json

class LearningRateLogger(keras.callbacks.Callback):
    """Log learning rate and send Slack notifications at milestones."""

    def __init__(self, slack_webhook_url=None):
        super().__init__()
        self.slack_webhook_url = slack_webhook_url

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Access current learning rate
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        print(f"\nEpoch {epoch + 1}: Learning rate is {lr:.6f}")

        # Check validation accuracy milestone
        val_accuracy = logs.get('val_accuracy', 0)
        if val_accuracy > 0.95 and self.slack_webhook_url:
            self._send_slack_message(
                f"🎉 Model reached 95% validation accuracy at epoch {epoch + 1}!"
            )

    def _send_slack_message(self, message):
        """Send notification to Slack webhook."""
        try:
            payload = {'text': message}
            requests.post(
                self.slack_webhook_url,
                data=json.dumps(payload),
                headers={'Content-Type': 'application/json'},
                timeout=5  # Don't let a slow webhook stall training
            )
        except Exception as e:
            print(f"Failed to send Slack notification: {e}")

# Usage
callback = LearningRateLogger(slack_webhook_url='https://hooks.slack.com/services/YOUR/WEBHOOK/URL')
model.fit(x_train, y_train, epochs=50, callbacks=[callback])
```
## Accessing Model Metrics and Parameters
Callbacks receive a logs dictionary containing current metrics and can access the model directly via self.model. This enables sophisticated monitoring and dynamic training adjustments.
Here’s a callback that tracks gradient norms and implements custom early stopping:
```python
import tensorflow as tf
from tensorflow import keras
import numpy as np

class GradientMonitor(keras.callbacks.Callback):
    """Monitor gradient norms and implement custom early stopping."""

    def __init__(self, sample_dataset, patience=5, grad_norm_threshold=100.0):
        super().__init__()
        self.sample_dataset = sample_dataset  # Dataset to draw a probe batch from
        self.patience = patience
        self.grad_norm_threshold = grad_norm_threshold
        self.wait = 0
        self.best_loss = np.inf
        self.gradient_norms = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Calculate gradient norm on a sample batch
        x_sample, y_sample = next(iter(self.sample_dataset))
        with tf.GradientTape() as tape:
            predictions = self.model(x_sample, training=True)
            loss = self.model.compiled_loss(y_sample, predictions)
        grads = tape.gradient(loss, self.model.trainable_weights)
        grad_norm = tf.linalg.global_norm(grads)
        self.gradient_norms.append(float(grad_norm))
        print(f"\nEpoch {epoch + 1}: Gradient norm = {grad_norm:.4f}")

        # Check for exploding gradients
        if grad_norm > self.grad_norm_threshold:
            print("⚠️ Warning: Gradient norm exceeds threshold!")
            self.model.stop_training = True
            return

        # Custom early stopping based on multiple metrics
        current_loss = logs.get('val_loss', np.inf)
        current_acc = logs.get('val_accuracy', 0)

        # Require at least 1% relative improvement in validation loss
        if current_loss < self.best_loss * 0.99:
            self.best_loss = current_loss
            self.wait = 0
        else:
            self.wait += 1
            # Stop only if loss has stalled AND accuracy is still low
            if self.wait >= self.patience and current_acc < 0.9:
                print(f"\nEarly stopping: No improvement for {self.patience} epochs")
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        print("\nGradient norm statistics:")
        print(f"  Mean: {np.mean(self.gradient_norms):.4f}")
        print(f"  Max: {np.max(self.gradient_norms):.4f}")
        print(f"  Min: {np.min(self.gradient_norms):.4f}")
```
## Advanced Callback Patterns
Callbacks can maintain state across epochs and implement complex training schedules. Here’s a learning rate scheduler with warmup and cyclic patterns:
```python
import tensorflow as tf
from tensorflow import keras
import numpy as np

class WarmupCyclicLR(keras.callbacks.Callback):
    """Learning rate scheduler with warmup and cyclic decay."""

    def __init__(self, base_lr=1e-4, max_lr=1e-2, warmup_epochs=5,
                 cycle_length=10, decay_factor=0.9):
        super().__init__()
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_epochs = warmup_epochs
        self.cycle_length = cycle_length
        self.decay_factor = decay_factor
        self.history = []

    def on_epoch_begin(self, epoch, logs=None):
        if epoch < self.warmup_epochs:
            # Warmup phase: linear increase from base_lr to max_lr
            lr = self.base_lr + (self.max_lr - self.base_lr) * (epoch / self.warmup_epochs)
        else:
            # Cyclic phase: cosine annealing within each cycle,
            # with the peak decayed once per completed cycle
            cycle_num = (epoch - self.warmup_epochs) // self.cycle_length
            cycle_pos = (epoch - self.warmup_epochs) % self.cycle_length
            cycle_progress = cycle_pos / self.cycle_length
            current_max_lr = self.max_lr * (self.decay_factor ** cycle_num)
            lr = self.base_lr + (current_max_lr - self.base_lr) * \
                (1 + np.cos(np.pi * cycle_progress)) / 2

        tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr)
        self.history.append(lr)
        print(f"\nEpoch {epoch + 1}: Learning rate = {lr:.6f}")

# Usage with multiple callbacks (order matters!)
callbacks = [
    WarmupCyclicLR(base_lr=1e-5, max_lr=1e-3),  # LR scheduler runs first
    keras.callbacks.ModelCheckpoint('model.h5', save_best_only=True),
    keras.callbacks.EarlyStopping(patience=10)
]
```
## Best Practices and Common Pitfalls
Avoid expensive operations in on_batch_end. This method executes after every batch—potentially thousands of times per epoch. Heavy computations here will cripple training speed.
```python
# ❌ BAD: Expensive operation every batch
class SlowCallback(keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        # This runs thousands of times per epoch!
        heavy_computation = np.linalg.svd(self.model.layers[0].get_weights()[0])
        self.save_to_database(heavy_computation)

# ✅ GOOD: Expensive operations only per epoch
class FastCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Runs once per epoch
        if epoch % 5 == 0:  # Even less frequently
            heavy_computation = np.linalg.svd(self.model.layers[0].get_weights()[0])
            self.save_to_database(heavy_computation)
```
Use callback execution order strategically. Callbacks execute in list order, which matters when they interact. Learning rate schedulers should run before checkpointing to ensure saved models reflect the correct learning rate.
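The ordering guarantee is just a loop over the list. A minimal pure-Python sketch of that dispatch (the `LRScheduler` and `Checkpointer` classes here are illustrative stand-ins, not Keras APIs) makes it concrete:

```python
class Callback:
    """Minimal stand-in for keras.callbacks.Callback."""
    def on_epoch_end(self, epoch, logs=None):
        pass

trace = []

class LRScheduler(Callback):
    def on_epoch_end(self, epoch, logs=None):
        trace.append('scheduler')  # adjust the learning rate here

class Checkpointer(Callback):
    def on_epoch_end(self, epoch, logs=None):
        trace.append('checkpoint')  # save the model here

def dispatch(callbacks, epoch, logs=None):
    # Keras walks callbacks in list order, so earlier entries act
    # (and mutate shared state) before later ones see it
    for cb in callbacks:
        cb.on_epoch_end(epoch, logs)

dispatch([LRScheduler(), Checkpointer()], epoch=0)
print(trace)  # ['scheduler', 'checkpoint']
```

Because the scheduler appears first, any state it changes is already in place when the checkpointer runs.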
Cache expensive lookups. If you need model information repeatedly, compute it once in on_train_begin rather than recalculating in every epoch:
```python
class OptimizedCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Cache layer references once
        self.monitored_layers = [
            layer for layer in self.model.layers
            if isinstance(layer, keras.layers.Dense)
        ]

    def on_epoch_end(self, epoch, logs=None):
        # Use cached references
        for layer in self.monitored_layers:
            weights = layer.get_weights()
            # Process weights...
```
Test callbacks independently. Create minimal models and datasets to verify callback behavior before deploying to expensive training runs.
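Because hooks are plain methods, one way to test them needs no TensorFlow at all: drive the hook directly with hand-written logs and a stub model. The sketch below is illustrative (the `FakeModel` and `PatienceStopper` classes are ours, assuming the callback only touches `stop_training`), showing the same patience pattern used earlier:

```python
class FakeModel:
    """Stub standing in for a Keras model: only what the callback touches."""
    def __init__(self):
        self.stop_training = False

class PatienceStopper:
    """Toy callback with the same patience-based early-stopping pattern."""
    def __init__(self, patience=3):
        self.patience = patience
        self.wait = 0
        self.best_loss = float('inf')
        self.model = None  # Keras normally wires this up via set_model()

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get('val_loss', float('inf'))
        if loss < self.best_loss:
            self.best_loss = loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True

# Drive the hook directly with hand-written logs: no model.fit() needed
cb = PatienceStopper(patience=2)
cb.model = FakeModel()
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.85]):
    cb.on_epoch_end(epoch, logs={'val_loss': loss})

print(cb.model.stop_training)  # True: loss stalled for 2 epochs after 0.8
```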
Callbacks transform TensorFlow from a static training framework into a dynamic, controllable system. Master them, and you’ll build more robust, observable, and efficient training pipelines.