
TensorFlow Callbacks

Introduction

When training deep learning models, you often want to monitor the training process, save checkpoints, stop training when certain conditions are met, or modify parameters dynamically. TensorFlow provides a powerful feature called callbacks that lets you perform these actions and more.

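Callbacks are objects whose methods Keras calls at specific points during training, such as the start or end of training, of each epoch, or of each batch. Inside these methods a callback can read the metrics logged so far (passed in a logs dictionary), access and modify the model through self.model (for example, its optimizer's learning rate), and even stop training.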
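To make this concrete, here is a deliberately simplified, pure-Python sketch of how a training loop might dispatch to callback hooks. ToyCallback, ToyModel, toy_fit and Recorder are hypothetical names. This is not how Keras is actually implemented. The sketch only borrows the hook names (on_train_begin, on_epoch_begin, on_epoch_end, on_train_end), the logs dictionary, and the stop_training flag that real Keras callbacks use.

```python
class ToyCallback:
    """Hypothetical base class: every hook does nothing unless overridden."""

    model = None  # set by toy_fit before training starts

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class ToyModel:
    def __init__(self):
        self.stop_training = False  # a callback may set this to end training


def toy_fit(model, losses, callbacks):
    """Run one 'epoch' per precomputed loss value, calling hooks in list order."""
    for cb in callbacks:
        cb.model = model
        cb.on_train_begin()
    for epoch, loss in enumerate(losses):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        logs = {"loss": loss}  # stands in for the epoch's logged metrics
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
        if model.stop_training:  # checked after every callback has run
            break
    for cb in callbacks:
        cb.on_train_end()


class Recorder(ToyCallback):
    """Record each epoch's loss and request a stop once it drops below 0.5."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_epoch_end(self, epoch, logs=None):
        self.events.append((epoch, logs["loss"]))
        if logs["loss"] < 0.5:
            self.model.stop_training = True


recorder = Recorder()
toy_fit(ToyModel(), [0.9, 0.7, 0.4, 0.3], [recorder])
print(recorder.events)  # [(0, 0.9), (1, 0.7), (2, 0.4)]
```

The stop flag is only checked after every callback's on_epoch_end has run, so all callbacks still see the final epoch.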

In this tutorial, you'll learn:

  • What callbacks are and why they're useful
  • Built-in callbacks in TensorFlow
  • How to create custom callbacks
  • Practical applications of callbacks in real-world scenarios

Why Use Callbacks?

Callbacks allow you to:

  • Monitor training progress: Track metrics during training and visualize them
  • Early stopping: Prevent overfitting by stopping training when the model stops improving
  • Save model checkpoints: Periodically save your model's weights
  • Adjust learning rates: Implement learning rate schedules to improve training efficiency
  • Log custom metrics: Track specialized metrics beyond standard loss and accuracy
  • Debug your model: Get insights into internal model states during training

Built-in Callbacks in TensorFlow

TensorFlow provides several built-in callbacks for common tasks. Let's explore the most useful ones:

ModelCheckpoint

This callback saves your model (or only its weights, with save_weights_only=True) during training. By default it does so at the end of every epoch.

```python
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint

# Create a simple model (expects flattened 28x28 images, e.g. MNIST)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Create checkpoint callback (saves the full model in the .keras format)
checkpoint_cb = ModelCheckpoint(
    filepath='model_{epoch:02d}_{val_accuracy:.3f}.keras',
    save_best_only=True,     # Save only when the model improves
    monitor='val_accuracy',  # The metric to monitor
    mode='max',              # 'max' because we want to maximize accuracy
    verbose=1                # Show messages
)

# Train the model with the checkpoint callback
# (x_train, y_train, x_val, y_val are assumed to be already-prepared NumPy arrays)
model.fit(x_train, y_train,
          epochs=10,
          validation_data=(x_val, y_val),
          callbacks=[checkpoint_cb])
```

The callback above saves the model whenever validation accuracy improves, with filenames including both epoch number and accuracy.
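The filepath is an ordinary Python format string. As we understand it, Keras fills {epoch} with the 1-based epoch number and fills the remaining placeholders from that epoch's logs. That is why val_accuracy can appear in the name. The pure-Python sketch below imitates the filename formatting and the save_best_only decision for mode='max'. The helper checkpoint_decisions and the accuracy values are hypothetical, and this is not the callback's real code.

```python
def checkpoint_decisions(val_accuracies,
                         template="model_{epoch:02d}_{val_accuracy:.3f}.keras"):
    """Return the filenames a save_best_only checkpoint (mode='max') would write."""
    best = float("-inf")
    saved = []
    for epoch, acc in enumerate(val_accuracies):
        if acc > best:  # strict improvement over the best value so far
            best = acc
            # epoch + 1 so that the first epoch appears as "01"
            saved.append(template.format(epoch=epoch + 1, val_accuracy=acc))
    return saved


# Hypothetical validation accuracies for five epochs
print(checkpoint_decisions([0.912, 0.935, 0.931, 0.948, 0.947]))
# ['model_01_0.912.keras', 'model_02_0.935.keras', 'model_04_0.948.keras']
```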

EarlyStopping

EarlyStopping helps prevent overfitting by monitoring a specific metric and stopping training when it stops improving.

```python
from tensorflow.keras.callbacks import EarlyStopping

# Create early stopping callback
early_stopping_cb = EarlyStopping(
    monitor='val_loss',        # Monitor validation loss
    patience=3,                # Stop after 3 epochs without improvement
    min_delta=0.001,           # Minimum change to qualify as improvement
    mode='min',                # 'min' because we want to minimize loss
    restore_best_weights=True  # Restore model weights from the epoch with the best value
)

# Train with the callback
history = model.fit(
    x_train, y_train,
    epochs=50,  # Maximum number of epochs
    validation_data=(x_val, y_val),
    callbacks=[early_stopping_cb]
)

print(f"Training stopped after {len(history.history['loss'])} epochs")
```

In the above example, if the validation loss doesn't improve by at least 0.001 for 3 consecutive epochs, training will stop.
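The interaction of patience and min_delta is easiest to see on a concrete sequence. The function below is a simplified, pure-Python sketch of the rule for mode='min', and it ignores other options such as baseline. The name early_stopping_epoch and the loss values are hypothetical.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.001):
    """Return (stop_epoch, best_epoch) for a mode='min' early-stopping rule."""
    best = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last qualifying improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:  # must beat the best by more than min_delta
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # training stops after this epoch
    return len(val_losses) - 1, best_epoch  # ran out of epochs first


# Hypothetical validation losses (epochs are 0-based here)
stop, best = early_stopping_epoch([0.90, 0.70, 0.60, 0.5995, 0.61, 0.605, 0.40])
print(stop, best)  # 5 2
```

Epoch 3 improves on the best loss by only 0.0005, which is less than min_delta. It therefore counts toward the patience of 3 instead of resetting it. Training stops after epoch 5, so six epochs run in total. With restore_best_weights=True, the weights from epoch 2 would be restored.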

ReduceLROnPlateau

This callback reduces the learning rate when a monitored metric has stopped improving. Taking smaller steps once training plateaus often lets the optimizer make further progress toward a lower loss.

```python
from tensorflow.keras.callbacks import ReduceLROnPlateau

reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,    # Reduce learning rate by 80%
    patience=2,    # Wait 2 epochs without improvement
    min_lr=1e-6,   # Don't reduce learning rate below this value
    verbose=1
)

model.fit(
    x_train, y_train,
    epochs=30,
    validation_data=(x_val, y_val),
    callbacks=[reduce_lr]
)
```

This is particularly useful for fine-tuning models, as reducing the learning rate can help the model converge to better results.
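Here is a rough pure-Python model of this behavior. Whenever the monitored loss fails to improve for patience epochs, the learning rate is multiplied by factor but never pushed below min_lr. The name simulate_reduce_lr, the starting rate of 1e-3 and the loss values are hypothetical. The sketch leaves out the callback's cooldown option, and min_delta=1e-4 reflects our understanding of the callback's default.

```python
def simulate_reduce_lr(val_losses, initial_lr=1e-3, factor=0.2, patience=2,
                       min_lr=1e-6, min_delta=1e-4):
    """Return the learning rate in effect after each epoch (mode='min')."""
    lr = initial_lr
    best = float("inf")
    wait = 0  # epochs without a qualifying improvement
    lrs = []
    for loss in val_losses:
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)  # shrink, but clip at min_lr
                wait = 0
        lrs.append(lr)  # this rate is used from the next epoch on
    return lrs


# Hypothetical validation losses: one improvement, then a plateau
lrs = simulate_reduce_lr([1.0, 0.8, 0.81, 0.82, 0.83, 0.84])
print([f"{lr:.0e}" for lr in lrs])
# ['1e-03', '1e-03', '1e-03', '2e-04', '2e-04', '4e-05']
```

With factor=0.2, five reductions would take 1e-3 down to 3.2e-7. The fifth reduction is therefore clipped to min_lr=1e-6.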

TensorBoard

TensorBoard is TensorFlow's visualization toolkit. The TensorBoard callback writes logs that TensorBoard uses to display metric curves, the model graph, weight histograms, and even embeddings.

```python
import datetime
from tensorflow.keras.callbacks import TensorBoard

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,    # Record histogram every epoch
    write_graph=True,    # Log the model graph
    update_freq='epoch'  # Update logs at the end of each epoch
)

model.fit(
    x_train, y_train,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[tensorboard_callback]
)
```

After training, you can view the logs using TensorBoard:

```bash
tensorboard --logdir logs/fit
```
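The TensorBoard callback writes the metrics Keras logs, but you may want an extra curve. One possible approach, sketched here under our own assumptions, is a small callback that writes its own scalar with tf.summary. The hypothetical GapLogger records val_loss minus loss each epoch. It assumes the model is trained with validation_data, so that val_loss is present in logs.

```python
import tensorflow as tf


class GapLogger(tf.keras.callbacks.Callback):
    """Hypothetical callback: log (val_loss - loss) to TensorBoard every epoch."""

    def __init__(self, log_dir):
        super().__init__()
        # A separate writer, so TensorBoard shows these scalars as their own run
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if "loss" in logs and "val_loss" in logs:
            gap = float(logs["val_loss"]) - float(logs["loss"])
            with self.writer.as_default():
                tf.summary.scalar("loss_gap", gap, step=epoch)
            self.writer.flush()
```

It could be passed alongside the TensorBoard callback, for example callbacks=[tensorboard_callback, GapLogger(log_dir + "/gap")]. TensorBoard would then show the gap as a separate run under the same log directory.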

CSVLogger

This simple but useful callback logs training metrics to a CSV file, making it easy to analyze them later.

```python
from tensorflow.keras.callbacks import CSVLogger

csv_logger = CSVLogger('training_history.csv', separator=',', append=False)

model.fit(
    x_train, y_train,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[csv_logger]
)
```
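CSVLogger produces an ordinary CSV file, so the standard library is enough to analyze it. The sketch below assumes the file has a header row with an epoch column plus one column per logged metric (such as val_loss). best_epoch is a hypothetical helper name.

```python
import csv


def best_epoch(csv_path, metric="val_loss", mode="min"):
    """Return (epoch, value) for the best row of a CSVLogger file."""
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))  # the header row supplies the keys
    if not rows:
        raise ValueError(f"{csv_path} contains no epochs")
    pick = min if mode == "min" else max
    best_row = pick(rows, key=lambda row: float(row[metric]))
    return int(best_row["epoch"]), float(best_row[metric])
```

For the run above, best_epoch('training_history.csv') would return the epoch index and value with the lowest val_loss. As far as we know, CSVLogger writes 0-based epoch indices, whereas checkpoint filenames use 1-based epoch numbers.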

Combining Multiple Callbacks

You can use multiple callbacks simultaneously to get the benefits of each:

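```python
from tensorflow.keras.callbacks import (CSVLogger, EarlyStopping, ModelCheckpoint,
                                        ReduceLROnPlateau, TensorBoard)

callbacks = [
    ModelCheckpoint('best_model.keras', save_best_only=True, monitor='val_accuracy'),
    EarlyStopping(patience=5, restore_best_weights=True),  # monitors val_loss by default
    ReduceLROnPlateau(factor=0.1, patience=3),             # monitors val_loss by default
    TensorBoard(log_dir='./logs'),
    CSVLogger('training_log.csv')
]

model.fit(
    x_train, y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=callbacks
)
```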
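To try a combined callback list without a real dataset, here is a self-contained sketch that trains a tiny model on synthetic, randomly generated data. The data, model size, hyperparameters and temporary log path are arbitrary assumptions, chosen only so that it runs in seconds. ReduceLROnPlateau gets a smaller patience than EarlyStopping, so the learning rate is lowered before training is abandoned. The list above uses the same ordering (3 versus 5).

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Synthetic binary task: the label is 1 when the 20 features sum to a positive value
rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 20)).astype("float32")
y = (x.sum(axis=1) > 0).astype("int32")
x_train, y_train, x_val, y_val = x[:800], y[:800], x[800:], y[800:]

model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

log_path = os.path.join(tempfile.mkdtemp(), "training_log.csv")
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                     restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5,
                                         patience=1),
    tf.keras.callbacks.CSVLogger(log_path),
]

history = model.fit(x_train, y_train, epochs=20, batch_size=32,
                    validation_data=(x_val, y_val),
                    callbacks=callbacks, verbose=0)
print(f"Ran {len(history.history['loss'])} of 20 epochs; log written to {log_path}")
```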

Creating Custom Callbacks

While the built-in callbacks are powerful, you might need custom functionality. You can create a custom callback by subclassing tf.keras.callbacks.Callback.

Here's an example of a custom callback that logs the learning rate at the end of each epoch:

```python
import numpy as np
import tensorflow as tf

class LearningRateLogger(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Use `learning_rate`; the older `lr` alias is missing in recent Keras
        # versions. The value may be a variable or tensor, so convert it to a float.
        lr = self.model.optimizer.learning_rate
        lr_value = float(np.array(lr))
        print(f"\nEpoch {epoch + 1}: Current learning rate: {lr_value:.6g}")

# Use the custom callback
model.fit(
    x_train, y_train,
    epochs=10,
    callbacks=[LearningRateLogger()]
)
```

Available callback methods include:

  • on_train_begin, on_train_end
  • on_epoch_begin, on_epoch_end
  • on_train_batch_begin, on_train_batch_end (on_batch_begin and on_batch_end are older aliases for these)
  • on_predict_begin, on_predict_end
  • on_predict_batch_begin, on_predict_batch_end
  • on_test_begin, on_test_end
  • on_test_batch_begin, on_test_batch_end
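As a further sketch of these hooks, the hypothetical AccuracyThresholdStop callback below records per-batch losses in on_train_batch_end. Once a chosen metric reaches a threshold, it sets self.model.stop_training = True, the same flag the built-in EarlyStopping uses. The monitored name and threshold are assumptions. The metric must appear in logs, so val_accuracy requires both validation_data and metrics=['accuracy'].

```python
import tensorflow as tf


class AccuracyThresholdStop(tf.keras.callbacks.Callback):
    """Hypothetical callback: stop training once `monitor` reaches `threshold`."""

    def __init__(self, monitor="val_accuracy", threshold=0.95):
        super().__init__()  # initialize the base Callback's attributes
        self.monitor = monitor
        self.threshold = threshold
        self.batch_losses = []

    def on_epoch_begin(self, epoch, logs=None):
        self.batch_losses = []  # start a fresh list for each epoch

    def on_train_batch_end(self, batch, logs=None):
        if logs and "loss" in logs:
            self.batch_losses.append(float(logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        value = (logs or {}).get(self.monitor)
        if value is not None and value >= self.threshold:
            print(f"\n{self.monitor}={value:.4f} reached {self.threshold}; "
                  f"stopping after epoch {epoch + 1}")
            self.model.stop_training = True
```

It is passed like any other callback, for example callbacks=[AccuracyThresholdStop(threshold=0.98)].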

Real-World Example: Training with Learning Rate Scheduling and Monitoring

Let's put everything together in a more complex example. We'll create a model for the MNIST dataset with:

  1. Learning rate scheduling
  2. Early stopping
  3. Model checkpointing
  4. Custom callback for monitoring

```python
import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, TensorBoard

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize pixel values to be between 0 and 1
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Add a channel dimension: (num_samples, 28, 28) -> (num_samples, 28, 28, 1)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

# Hold out the last 10,000 training images as a validation set
val_size = 10000
x_val = x_train[-val_size:]
y_val = y_train[-val_size:]
x_train = x_train[:-val_size]
y_train = y_train[:-val_size]

# Create a convolutional model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Custom learning rate schedule: halve the rate every 5 epochs (epoch is 0-based)
def step_decay(epoch):
    initial_lr = 0.001
    drop_rate = 0.5
    epochs_drop = 5
    lr = initial_lr * (drop_rate ** np.floor((1 + epoch) / epochs_drop))
    return float(lr)

# Custom callback to track some details during training
class TrainingMonitor(tf.keras.callbacks.Callback):
    def __init__(self, test_data):
        super().__init__()  # Initialize the base Callback
        self.test_images, self.test_labels = test_data
        self.incorrect_predictions = []

    def on_epoch_end(self, epoch, logs=None):
        if epoch % 5 == 0:  # Check on epochs 0, 5, 10, ... (0-based)
            predictions = self.model.predict(self.test_images, verbose=0)
            predicted_classes = np.argmax(predictions, axis=1)
            true_classes = self.test_labels

            # Find incorrectly predicted samples
            incorrect_indices = np.where(predicted_classes != true_classes)[0]

            if len(incorrect_indices) > 0:
                # Store the first incorrect prediction for later analysis
                sample_idx = incorrect_indices[0]
                self.incorrect_predictions.append({
                    'epoch': epoch,
                    'index': sample_idx,
                    'true_label': true_classes[sample_idx],
                    'predicted': predicted_classes[sample_idx],
                    'confidence': predictions[sample_idx][predicted_classes[sample_idx]]
                })
                print(f"\nSample incorrect prediction at index {sample_idx}:")
                print(f"True label: {true_classes[sample_idx]}, "
                      f"Predicted: {predicted_classes[sample_idx]}")

# Set up all callbacks
lr_scheduler = LearningRateScheduler(step_decay)

early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=10,
    restore_best_weights=True
)

checkpoint = ModelCheckpoint(
    'mnist_model_{epoch:02d}_{val_accuracy:.4f}.keras',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1
)

training_monitor = TrainingMonitor(test_data=(x_test[:1000], y_test[:1000]))

# Compile the model
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model with all callbacks
history = model.fit(
    x_train, y_train,
    epochs=30,
    batch_size=64,
    validation_data=(x_val, y_val),
    callbacks=[
        lr_scheduler,
        early_stopping,
        checkpoint,
        tensorboard,
        training_monitor
    ]
)

# Evaluate the model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"\nTest accuracy: {test_acc:.4f}")

# Plot training history
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='lower right')

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')
plt.tight_layout()
plt.show()
```

This example demonstrates a more comprehensive training setup that uses callbacks to:

  1. Decay the learning rate on a schedule
  2. Stop early if validation metrics don't improve
  3. Save the best model checkpoints
  4. Monitor and log training with TensorBoard
  5. Track specific details with a custom callback
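To see what the step_decay schedule above actually produces, the same formula can be evaluated on its own. This standalone copy uses math.floor in place of np.floor so that it needs no third-party packages. LearningRateScheduler passes 0-based epoch indices. The first halving therefore takes effect at index 4 (the fifth epoch), and the rate halves again every five epochs after that.

```python
import math


def step_decay(epoch):
    """Same step-decay formula as above: halve the rate every 5 epochs."""
    initial_lr = 0.001
    drop_rate = 0.5
    epochs_drop = 5
    return initial_lr * drop_rate ** math.floor((1 + epoch) / epochs_drop)


for epoch in (0, 3, 4, 8, 9, 14, 29):
    print(f"epoch index {epoch:2d}: learning rate {step_decay(epoch):g}")
```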

Summary

TensorFlow callbacks provide a powerful mechanism to customize, monitor, and control the training process of your models. They help you:

  • Save the best models during training
  • Stop training early to prevent overfitting
  • Adjust learning rates adaptively
  • Log metrics and model states
  • Visualize training progress
  • Implement custom monitoring logic

Using callbacks effectively can significantly improve both your model performance and your development workflow.

Additional Resources

To learn more about TensorFlow callbacks, check out these resources:

  1. TensorFlow official documentation on callbacks
  2. TensorFlow Guide: Callbacks
  3. TensorBoard visualization tutorial

Exercises

  1. Create a custom callback that saves an image of the model's predictions on a fixed set of validation examples at the end of each epoch.
  2. Implement a "warmup" learning rate scheduler that slowly increases the learning rate for the first few epochs before applying a standard decay schedule.
  3. Create a callback that stops training if the model's accuracy on a specific test sample doesn't reach a threshold after a certain number of epochs.
  4. Use the CSVLogger callback to log training metrics, then write a script to analyze the resulting CSV file and plot learning curves.
  5. Combine EarlyStopping and ReduceLROnPlateau to create a training strategy that first reduces the learning rate when progress stalls and only stops training if reducing the learning rate doesn't help.

