TensorFlow Callbacks
Introduction
When training deep learning models, you often want to monitor the training process, save checkpoints, stop training when certain conditions are met, or modify parameters dynamically. TensorFlow provides a powerful feature called callbacks that lets you perform these actions and more.
Callbacks are objects that are called at specific points during model training, such as at the start or end of an epoch, batch, or training session. They can access and modify the model state, training parameters, and metrics.
In this tutorial, you'll learn:
- What callbacks are and why they're useful
- Built-in callbacks in TensorFlow
- How to create custom callbacks
- Practical applications of callbacks in real-world scenarios
Why Use Callbacks?
Callbacks allow you to:
- Monitor training progress: Track metrics during training and visualize them
- Early stopping: Prevent overfitting by stopping training when the model stops improving
- Save model checkpoints: Periodically save your model's weights
- Adjust learning rates: Implement learning rate schedules to improve training efficiency
- Log custom metrics: Track specialized metrics beyond standard loss and accuracy
- Debug your model: Get insights into internal model states during training
Built-in Callbacks in TensorFlow
TensorFlow provides several built-in callbacks for common tasks. Let's explore the most useful ones:
ModelCheckpoint
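This callback saves your model (or only its weights, if you pass save_weights_only=True) during training, either at a fixed interval or only when a monitored metric improves.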
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
# Create a simple model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Create checkpoint callback
checkpoint_cb = ModelCheckpoint(
    filepath='model_weights_{epoch:02d}_{val_accuracy:.3f}.h5',
    save_best_only=True,     # Save only when the model improves
    monitor='val_accuracy',  # The metric to monitor
    mode='max',              # 'max' because we want to maximize accuracy
    verbose=1                # Show messages
)
# Train the model with the checkpoint callback
model.fit(x_train, y_train,
          epochs=10,
          validation_data=(x_val, y_val),
          callbacks=[checkpoint_cb])
The callback above saves the model whenever validation accuracy improves, with filenames including both epoch number and accuracy.
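To see how the filename template is filled in, here is a small sketch using plain Python string formatting, which is the mechanism Keras relies on to expand the placeholders (it passes the 1-based epoch number together with the logged metrics). The metric values are made up for illustration.

```python
# Filename template from the ModelCheckpoint example above.
template = "model_weights_{epoch:02d}_{val_accuracy:.3f}.h5"

# Hypothetical metrics for one epoch (illustrative values, not real results).
logs = {"loss": 0.21, "val_loss": 0.25, "val_accuracy": 0.9312}

# Keys in `logs` that the template does not mention are simply ignored.
filename = template.format(epoch=3, **logs)
print(filename)  # model_weights_03_0.931.h5
```

Because the monitored metric is part of the name, each improvement produces a new file rather than overwriting the previous one.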
EarlyStopping
EarlyStopping helps prevent overfitting by monitoring a specific metric and stopping training when it stops improving.
from tensorflow.keras.callbacks import EarlyStopping
# Create early stopping callback
early_stopping_cb = EarlyStopping(
    monitor='val_loss',         # Monitor validation loss
    patience=3,                 # Stop after 3 epochs without improvement
    min_delta=0.001,            # Minimum change to qualify as improvement
    mode='min',                 # 'min' because we want to minimize loss
    restore_best_weights=True   # Restore model weights from the epoch with the best value
)
# Train with the callback
history = model.fit(
    x_train, y_train,
    epochs=50,  # Maximum number of epochs
    validation_data=(x_val, y_val),
    callbacks=[early_stopping_cb]
)
print(f"Training stopped after {len(history.history['loss'])} epochs")
In the above example, if the validation loss doesn't improve by at least 0.001 for 3 consecutive epochs, training will stop.
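The patience rule can be sketched in plain Python. This is a simplified illustration of the logic described above, not Keras' actual implementation; the function name epochs_before_stop and the loss values are hypothetical.

```python
def epochs_before_stop(val_losses, patience=3, min_delta=0.001):
    """Return how many epochs run before the patience rule stops training."""
    best = float("inf")
    wait = 0  # Consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement of more than min_delta: record it and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch + 1
    return len(val_losses)


# Made-up validation losses: after 0.40, the changes are smaller than min_delta
val_losses = [0.50, 0.40, 0.3995, 0.3999, 0.3992, 0.30]
print(epochs_before_stop(val_losses))  # 5
```

Note that the final value of 0.30 is never reached: three epochs in a row fail to beat 0.40 by at least 0.001, so training stops after the fifth epoch.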
ReduceLROnPlateau
This callback reduces the learning rate when a metric has stopped improving, which can help find better local minima.
from tensorflow.keras.callbacks import ReduceLROnPlateau
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,     # Reduce learning rate by 80%
    patience=2,     # Wait 2 epochs without improvement
    min_lr=1e-6,    # Don't reduce learning rate below this value
    verbose=1
)
model.fit(
    x_train, y_train,
    epochs=30,
    validation_data=(x_val, y_val),
    callbacks=[reduce_lr]
)
This is particularly useful for fine-tuning models, as reducing the learning rate can help the model converge to better results.
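As a rough illustration of the arithmetic, each reduction multiplies the current learning rate by factor and clamps the result at min_lr. The sketch below assumes a starting rate of 0.001 (the Adam default) and a plateau at every step; the helper name plateau_reductions is hypothetical.

```python
def plateau_reductions(initial_lr, factor=0.2, min_lr=1e-6, steps=6):
    """Learning rates after successive plateau-triggered reductions."""
    lrs = []
    lr = initial_lr
    for _ in range(steps):
        # new_lr = old_lr * factor, but never below min_lr
        lr = max(lr * factor, min_lr)
        lrs.append(lr)
    return lrs


for i, lr in enumerate(plateau_reductions(1e-3), start=1):
    print(f"after reduction {i}: {lr:.2e}")
```

With these settings, the rate reaches the min_lr floor on the fifth reduction and stays there.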
TensorBoard
TensorBoard is TensorFlow's visualization tool. The TensorBoard callback writes logs that let you track metrics, inspect the model graph, and explore embeddings.
import datetime
from tensorflow.keras.callbacks import TensorBoard
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,     # Record histogram every epoch
    write_graph=True,     # Log the model graph
    update_freq='epoch'   # Update logs at the end of each epoch
)
model.fit(
    x_train, y_train,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[tensorboard_callback]
)
After training, you can view the logs using TensorBoard:
tensorboard --logdir logs/fit
CSVLogger
This simple but useful callback logs training metrics to a CSV file, making it easy to analyze them later.
from tensorflow.keras.callbacks import CSVLogger
csv_logger = CSVLogger('training_history.csv', separator=',', append=False)
model.fit(
    x_train, y_train,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[csv_logger]
)
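Once training finishes, the file can be read back with Python's standard csv module. The sketch below uses an in-memory sample instead of the real file so that it runs on its own; the column layout is what CSVLogger typically writes for this setup, and the numbers are made up.

```python
import csv
import io

# Made-up sample in the layout CSVLogger typically produces (epoch is 0-based).
sample_csv = """epoch,accuracy,loss,val_accuracy,val_loss
0,0.91,0.30,0.95,0.17
1,0.96,0.14,0.97,0.11
2,0.97,0.10,0.96,0.12
"""

# For a real run, replace io.StringIO(...) with
# open('training_history.csv', newline='').
rows = list(csv.DictReader(io.StringIO(sample_csv)))

# Pick the epoch with the lowest validation loss.
best = min(rows, key=lambda row: float(row["val_loss"]))
print(f"Best epoch: {best['epoch']} (val_loss={best['val_loss']})")
```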
Combining Multiple Callbacks
You can use multiple callbacks simultaneously to get the benefits of each:
callbacks = [
    ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_accuracy'),
    EarlyStopping(patience=5, restore_best_weights=True),
    ReduceLROnPlateau(factor=0.1, patience=3),
    TensorBoard(log_dir='./logs'),
    CSVLogger('training_log.csv')
]
model.fit(
    x_train, y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=callbacks
)
Creating Custom Callbacks
While the built-in callbacks are powerful, you might need custom functionality. You can create a custom callback by subclassing tf.keras.callbacks.Callback.
Here's an example of a custom callback that logs the learning rate at the end of each epoch:
class LearningRateLogger(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Use `learning_rate`; the shorter `lr` alias may be unavailable in newer Keras releases
        lr = self.model.optimizer.learning_rate
        # Variables and tensors expose .numpy(); plain Python floats do not
        if hasattr(lr, 'numpy'):
            lr_value = float(lr.numpy())
        else:
            lr_value = lr
        print(f"\nEpoch {epoch+1}: Current learning rate: {lr_value}")
# Use the custom callback
model.fit(
    x_train, y_train,
    epochs=10,
    callbacks=[LearningRateLogger()]
)
Available callback methods include:
- on_train_begin, on_train_end
- on_epoch_begin, on_epoch_end
- on_batch_begin, on_batch_end
- on_predict_begin, on_predict_end
- on_predict_batch_begin, on_predict_batch_end
- on_test_begin, on_test_end
- on_test_batch_begin, on_test_batch_end
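To make the order of the training hooks concrete, here is a plain-Python sketch. RecordingCallback and toy_fit are hypothetical stand-ins written for this illustration; they do not use TensorFlow and perform no real training, and only mimic the order in which fit() invokes the training hooks when there is no validation step.

```python
class RecordingCallback:
    """Plain-Python stand-in for a Keras callback; records which hooks fire."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_batch_end(self, batch, logs=None):
        self.events.append(f"batch_end:{batch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


def toy_fit(callback, epochs, batches_per_epoch):
    """Hypothetical, heavily simplified training loop (no real training)."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            # A real loop would run one optimization step here.
            callback.on_batch_end(batch, logs={})
        callback.on_epoch_end(epoch, logs={})
    callback.on_train_end()


recorder = RecordingCallback()
toy_fit(recorder, epochs=2, batches_per_epoch=2)
print(recorder.events)
```

In a real callback, the logs dictionary passed to each hook carries the current metrics, such as loss and val_loss at the end of an epoch.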
Real-World Example: Training with Learning Rate Scheduling and Monitoring
Let's put everything together in a more complex example. We'll create a model for the MNIST dataset with:
- Learning rate scheduling
- Early stopping
- Model checkpointing
- Custom callback for monitoring
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, TensorBoard
import datetime
# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Normalize pixel values to be between 0 and 1
x_train, x_test = x_train / 255.0, x_test / 255.0
# Reshape data for the model
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
# Split validation set from training set
val_size = 10000
x_val = x_train[-val_size:]
y_val = y_train[-val_size:]
x_train = x_train[:-val_size]
y_train = y_train[:-val_size]
# Create a convolutional model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax')
])
# Custom learning rate scheduler
def step_decay(epoch):
    initial_lr = 0.001
    drop_rate = 0.5
    epochs_drop = 5
    lr = initial_lr * (drop_rate ** np.floor((1 + epoch) / epochs_drop))
    return lr
# Custom callback to track some details during training
class TrainingMonitor(tf.keras.callbacks.Callback):
    def __init__(self, test_data):
        super().__init__()  # Let the base class set up its own state
        self.test_data = test_data
        self.test_images, self.test_labels = test_data
        self.incorrect_predictions = []

    def on_epoch_end(self, epoch, logs=None):
        if epoch % 5 == 0:  # Check every 5 epochs
            # verbose=0 keeps predict() from printing its own progress bar
            predictions = self.model.predict(self.test_images, verbose=0)
            predicted_classes = np.argmax(predictions, axis=1)
            true_classes = self.test_labels
            # Find incorrectly predicted samples
            incorrect = (predicted_classes != true_classes)
            incorrect_indices = np.where(incorrect)[0]
            if len(incorrect_indices) > 0:
                # Store some incorrect predictions for later analysis
                sample_idx = incorrect_indices[0]
                self.incorrect_predictions.append({
                    'epoch': epoch,
                    'index': sample_idx,
                    'true_label': true_classes[sample_idx],
                    'predicted': predicted_classes[sample_idx],
                    'confidence': predictions[sample_idx][predicted_classes[sample_idx]]
                })
                print(f"\nSample incorrect prediction at index {sample_idx}:")
                print(f"True label: {true_classes[sample_idx]}, Predicted: {predicted_classes[sample_idx]}")
# Set up all callbacks
lr_scheduler = LearningRateScheduler(step_decay)
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=10,
    restore_best_weights=True
)
checkpoint = ModelCheckpoint(
    'mnist_model_{epoch:02d}_{val_accuracy:.4f}.h5',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1
)
training_monitor = TrainingMonitor(test_data=(x_test[:1000], y_test[:1000]))
# Compile the model
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
# Train the model with all callbacks
history = model.fit(
    x_train, y_train,
    epochs=30,
    batch_size=64,
    validation_data=(x_val, y_val),
    callbacks=[
        lr_scheduler,
        early_stopping,
        checkpoint,
        tensorboard,
        training_monitor
    ]
)
# Evaluate the model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"\nTest accuracy: {test_acc:.4f}")
# Plot training history
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='lower right')
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')
plt.tight_layout()
plt.show()
This example demonstrates a more comprehensive training setup that uses callbacks to:
- Decay the learning rate on a schedule
- Stop early if validation metrics don't improve
- Save the best model checkpoints
- Monitor and log training with TensorBoard
- Track specific details with a custom callback
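To check what the step_decay schedule from this example actually produces, the same formula can be evaluated on its own. This sketch swaps np.floor for math.floor so it runs without NumPy, and exposes the constants as keyword arguments for convenience; otherwise the arithmetic is unchanged.

```python
import math


def step_decay(epoch, initial_lr=0.001, drop_rate=0.5, epochs_drop=5):
    """Halve the learning rate every `epochs_drop` epochs (same formula as above)."""
    return initial_lr * drop_rate ** math.floor((1 + epoch) / epochs_drop)


# `epoch` is the zero-based index that LearningRateScheduler passes in.
for epoch in (0, 3, 4, 8, 9, 14):
    print(f"epoch index {epoch}: lr = {step_decay(epoch):.6f}")
```

Because of the 1 + epoch term, the first halving takes effect at epoch index 4 (the fifth epoch) rather than index 5, and each later halving follows five epochs after the previous one.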
Summary
TensorFlow callbacks provide a powerful mechanism to customize, monitor, and control the training process of your models. They help you:
- Save the best models during training
- Stop training early to prevent overfitting
- Adjust learning rates adaptively
- Log metrics and model states
- Visualize training progress
- Implement custom monitoring logic
Using callbacks effectively can significantly improve both your model performance and your development workflow.
Additional Resources
To learn more about TensorFlow callbacks, check out these resources:
- TensorFlow official documentation on callbacks
- TensorFlow Guide: Callbacks
- TensorBoard visualization tutorial
Exercises
- Create a custom callback that saves an image of the model's predictions on a fixed set of validation examples at the end of each epoch.
- Implement a "warmup" learning rate scheduler that slowly increases the learning rate for the first few epochs before applying a standard decay schedule.
- Create a callback that stops training if the model's accuracy on a specific test sample doesn't reach a threshold after a certain number of epochs.
- Use the CSVLogger callback to log training metrics, then write a script to analyze the resulting CSV file and plot learning curves.
- Combine EarlyStopping and ReduceLROnPlateau to create a training strategy that first reduces the learning rate when progress stalls and only stops training if reducing the learning rate doesn't help.