# TensorFlow Checkpoints
## Introduction
When training machine learning models, especially complex neural networks that may take hours or days to train, it's critical to save your progress. TensorFlow Checkpoints provide a mechanism to save and restore the state of your model during training, allowing you to:
- Resume training after interruption
- Prevent data loss in case of system failures
- Share pre-trained models with others
- Deploy models to production
In this tutorial, we'll explore how to implement checkpoints in TensorFlow, why they're essential, and best practices for managing them in your machine learning projects.
## What are TensorFlow Checkpoints?
Checkpoints are files that contain the weights and state of your model at a specific point during training. Unlike saving the entire model architecture, checkpoints primarily store:
- Model weights/parameters
- Optimizer state (important for resuming training with momentum-based optimizers like Adam)
- Any other values you explicitly track, such as the current epoch or step counter
Think of checkpoints as "snapshots" of your model's training progress that you can return to at any time.
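To make the "snapshot" idea concrete, here is a minimal sketch that saves a few tracked objects and then lists what the checkpoint file actually contains. The `step` variable is an illustrative addition, not something TensorFlow adds for you:

```python
import os
import tensorflow as tf

# Anything attached to a tf.train.Checkpoint is tracked and saved.
step = tf.Variable(1, name="global_step")   # illustrative bookkeeping variable
optimizer = tf.keras.optimizers.Adam()
layer = tf.keras.layers.Dense(4)
layer.build(input_shape=(None, 8))          # create the layer's weight variables

os.makedirs("demo_ckpts", exist_ok=True)
ckpt = tf.train.Checkpoint(step=step, optimizer=optimizer, layer=layer)
save_path = ckpt.save("demo_ckpts/ckpt")    # e.g. demo_ckpts/ckpt-1

# Inspect the variables stored in the checkpoint file.
for name, shape in tf.train.list_variables(save_path):
    print(name, shape)
```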
## Basic Checkpoint Implementation
Let's start with a simple example of how to save and load checkpoints in TensorFlow:
```python
import tensorflow as tf
import numpy as np

# Create a simple model
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Create a checkpoint callback.
# Note: the older `period` argument is deprecated; `save_freq` counts
# batches, and with 800 training samples (after the 20% validation split)
# and the default batch size of 32, one epoch is 25 batches.
checkpoint_path = "training_checkpoints/cp-{epoch:04d}.ckpt"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=5 * 25  # Save weights every 5 epochs
)

# Create some dummy data for demonstration
x_train = np.random.random((1000, 784))
y_train = np.random.randint(0, 10, (1000,))

# Train the model with the checkpoint callback
model.fit(
    x_train, y_train,
    epochs=15,
    callbacks=[checkpoint_callback],
    validation_split=0.2
)

print("Training completed with checkpoints saved!")
```
When you run this code, you'll see output similar to:
```
Epoch 1/15
25/25 [==============================] - 2s 22ms/step - loss: 2.3026 - accuracy: 0.1020 - val_loss: 2.3026 - val_accuracy: 0.1050
...
Epoch 5/15
25/25 [==============================] - 1s 21ms/step - loss: 2.2701 - accuracy: 0.1434 - val_loss: 2.2687 - val_accuracy: 0.1500
Saving weights to training_checkpoints/cp-0005.ckpt
...
Epoch 10/15
25/25 [==============================] - 1s 22ms/step - loss: 2.1463 - accuracy: 0.2030 - val_loss: 2.1442 - val_accuracy: 0.2300
Saving weights to training_checkpoints/cp-0010.ckpt
...
Epoch 15/15
25/25 [==============================] - 1s 21ms/step - loss: 1.9830 - accuracy: 0.3064 - val_loss: 1.9801 - val_accuracy: 0.3050
Saving weights to training_checkpoints/cp-0015.ckpt
Training completed with checkpoints saved!
```
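On disk, each checkpoint is a set of files rather than a single `.ckpt` file. Assuming the run above, listing the directory shows an index file and a data shard per save, plus a `checkpoint` text file that records the most recent one:

```python
import os

# e.g. ['checkpoint', 'cp-0005.ckpt.data-00000-of-00001', 'cp-0005.ckpt.index', ...]
print(sorted(os.listdir("training_checkpoints")))
```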
## Loading from a Checkpoint
After creating checkpoints, you'll want to be able to restore your model. Here's how to load weights from the latest checkpoint:
```python
# Create a fresh model with the same architecture
fresh_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Compile the model
fresh_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Load the latest checkpoint
latest_checkpoint = tf.train.latest_checkpoint("training_checkpoints")
fresh_model.load_weights(latest_checkpoint)

# Evaluate the restored model
loss, acc = fresh_model.evaluate(x_train, y_train, verbose=2)
print(f"Restored model, accuracy: {acc:.5f}")
```
Output:
```
31/31 - 0s - loss: 1.9830 - accuracy: 0.3064
Restored model, accuracy: 0.30640
```
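You aren't limited to the latest checkpoint; any earlier save can be loaded by passing its path prefix directly (assuming the epoch-10 checkpoint from the run above exists):

```python
# Roll back to the weights saved at epoch 10.
fresh_model.load_weights("training_checkpoints/cp-0010.ckpt")
```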
## Checkpoint Callback Options
The `ModelCheckpoint` callback has several useful parameters:
```python
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="model_checkpoint.h5",
    save_best_only=True,       # Only save when the monitored metric improves
    save_weights_only=False,   # Save the entire model, not just weights
    monitor='val_accuracy',    # Metric to monitor
    mode='max',                # 'max' means higher is better
    verbose=1                  # Print saving messages
)
```
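Because `save_weights_only=False` writes the full model here (architecture, weights, and optimizer state), it can later be restored without redefining the architecture:

```python
# Re-create the complete model from the HDF5 checkpoint file.
best_model = tf.keras.models.load_model("model_checkpoint.h5")
```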
## Custom Checkpoint Manager
For more control over your checkpoints, TensorFlow provides `tf.train.CheckpointManager`. This is particularly useful for:
- Managing the number of checkpoints you keep
- Automatically deleting old checkpoints
- Creating custom checkpoint logic
Here's how to implement it:
```python
import tensorflow as tf
import os

# Create model and optimizer
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Create checkpoint objects
checkpoint_dir = './tf_ckpts'
os.makedirs(checkpoint_dir, exist_ok=True)

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(
    checkpoint,
    directory=checkpoint_dir,
    max_to_keep=3  # Keep only the 3 latest checkpoints
)

# Training loop with manual checkpoint saving
for epoch in range(5):
    # Simulate training progress
    loss = 1.0 / (epoch + 1)
    print(f"Epoch {epoch + 1}, Loss: {loss:.4f}")

    # Save a checkpoint at the end of every epoch
    save_path = manager.save()
    print(f"Saved checkpoint: {save_path}")

print(f"Latest checkpoint: {manager.latest_checkpoint}")
```
Output:
```
Epoch 1, Loss: 1.0000
Saved checkpoint: ./tf_ckpts/ckpt-1
Epoch 2, Loss: 0.5000
Saved checkpoint: ./tf_ckpts/ckpt-2
Epoch 3, Loss: 0.3333
Saved checkpoint: ./tf_ckpts/ckpt-3
Epoch 4, Loss: 0.2500
Saved checkpoint: ./tf_ckpts/ckpt-4
Epoch 5, Loss: 0.2000
Saved checkpoint: ./tf_ckpts/ckpt-5
Latest checkpoint: ./tf_ckpts/ckpt-5
```
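Because `max_to_keep=3`, the manager has already deleted `ckpt-1` and `ckpt-2` by this point. You can confirm which checkpoints remain:

```python
# Only the three most recent checkpoints are retained on disk.
print(manager.checkpoints)
# ['./tf_ckpts/ckpt-3', './tf_ckpts/ckpt-4', './tf_ckpts/ckpt-5']
```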
## Restoring from Custom Checkpoints
To restore a model using the `Checkpoint` and `CheckpointManager` objects:
```python
# Create a new model and optimizer
new_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])
new_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Create a checkpoint object
new_checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)

# Restore from the latest checkpoint
new_checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
print(f"Model restored from {tf.train.latest_checkpoint(checkpoint_dir)}")
```
Output:
```
Model restored from ./tf_ckpts/ckpt-5
```
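`restore` returns a status object you can use to verify the load; a common pattern is to restore before the training loop and fall back to fresh initialization when no checkpoint exists. A small sketch (note that `assert_existing_objects_matched` is the appropriate check here, since Adam's slot variables are created lazily and would fail a stricter `assert_consumed`):

```python
status = new_checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
if tf.train.latest_checkpoint(checkpoint_dir):
    # Raises if a variable that already exists in new_model/new_optimizer
    # was not found in the checkpoint file.
    status.assert_existing_objects_matched()
else:
    print("No checkpoint found; initializing from scratch")
```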
## Real-world Example: Resuming Training
Let's demonstrate a practical scenario where training is interrupted and needs to be resumed:
```python
import tensorflow as tf
import numpy as np
import os

# Function to create our model architecture
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.4),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

# Set up the checkpoint directory
checkpoint_dir = './training_resume_checkpoints'
checkpoint_path = os.path.join(checkpoint_dir, 'model_checkpoint.ckpt')
os.makedirs(checkpoint_dir, exist_ok=True)

# Create the model
model = create_model()

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1,
    save_freq='epoch'  # Save every epoch
)

# Create sample data
x_train = np.random.random((1000, 784))
y_train = np.random.randint(0, 10, (1000,))

# Initial training (interrupted after 3 epochs)
print("Starting initial training for 3 epochs...")
initial_epochs = 3
history = model.fit(
    x_train, y_train,
    epochs=initial_epochs,
    callbacks=[cp_callback],
    validation_split=0.2
)

# Simulate a training interruption
print("\nTraining interrupted! Now resuming from checkpoint...")

# Create a fresh model
model = create_model()

# Load the previously saved weights
model.load_weights(checkpoint_path)

# Resume training for more epochs
total_epochs = 6
history = model.fit(
    x_train, y_train,
    epochs=total_epochs,
    initial_epoch=initial_epochs,  # Resume the epoch count at 3, so epochs 4-6 run
    callbacks=[cp_callback],
    validation_split=0.2
)

print("Training completed successfully!")
```
Output:
```
Starting initial training for 3 epochs...
Epoch 1/3
25/25 [==============================] - 2s 81ms/step - loss: 2.3019 - accuracy: 0.1070 - val_loss: 2.3031 - val_accuracy: 0.0850
Saving model to training_resume_checkpoints/model_checkpoint.ckpt
Epoch 2/3
25/25 [==============================] - 2s 67ms/step - loss: 2.2968 - accuracy: 0.1060 - val_loss: 2.2978 - val_accuracy: 0.1050
Saving model to training_resume_checkpoints/model_checkpoint.ckpt
Epoch 3/3
25/25 [==============================] - 2s 68ms/step - loss: 2.2918 - accuracy: 0.1090 - val_loss: 2.2927 - val_accuracy: 0.1050
Saving model to training_resume_checkpoints/model_checkpoint.ckpt

Training interrupted! Now resuming from checkpoint...
Epoch 4/6
25/25 [==============================] - 2s 67ms/step - loss: 2.2867 - accuracy: 0.1180 - val_loss: 2.2877 - val_accuracy: 0.1050
Saving model to training_resume_checkpoints/model_checkpoint.ckpt
Epoch 5/6
25/25 [==============================] - 2s 68ms/step - loss: 2.2818 - accuracy: 0.1320 - val_loss: 2.2828 - val_accuracy: 0.1000
Saving model to training_resume_checkpoints/model_checkpoint.ckpt
Epoch 6/6
25/25 [==============================] - 2s 67ms/step - loss: 2.2768 - accuracy: 0.1350 - val_loss: 2.2780 - val_accuracy: 0.1100
Saving model to training_resume_checkpoints/model_checkpoint.ckpt
Training completed successfully!
```
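One caveat about this example: `save_weights_only=True` stores the model's weights but not the optimizer's state, so Adam's moment estimates restart from zero after the reload. When exact resumption matters, save both together with a `tf.train.Checkpoint`, as in the earlier section. A hedged sketch (the `full_state` filename is illustrative):

```python
# Save model weights *and* optimizer state together.
ckpt = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
ckpt.write(os.path.join(checkpoint_dir, 'full_state'))

# After recreating the model, restore both in one call.
model = create_model()
ckpt = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
ckpt.restore(os.path.join(checkpoint_dir, 'full_state'))
```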
## Best Practices for Using Checkpoints
- Save Regularly: Create checkpoints at reasonable intervals (every epoch or every N steps)
- Manage Disk Space: Don't keep too many checkpoints unless necessary
- Use `save_best_only`: Save only when your model improves on a key metric
- Version Your Checkpoints: Include the epoch number or a timestamp in filenames (see the sketch after this list)
- Include Optimizer State: Save both model weights and optimizer state for proper training resumption
- Keep Backup Copies: For critical models, maintain backup copies in different locations
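For versioning, the `ModelCheckpoint` `filepath` can interpolate the epoch number and any logged metric, as in this brief sketch (the metric key must match one your model actually logs):

```python
versioned_callback = tf.keras.callbacks.ModelCheckpoint(
    # {epoch} and {val_accuracy} are filled in from training logs at save time.
    filepath='checkpoints/cp-{epoch:04d}-{val_accuracy:.3f}.ckpt',
    save_weights_only=True,
    verbose=1
)
```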
## Checkpoints vs. Saved Models
TensorFlow offers multiple ways to save models. Here's how checkpoints compare to other methods:
| Feature | Checkpoints | SavedModel | HDF5 (.h5) |
| --- | --- | --- | --- |
| Format | TensorFlow-specific | Portable format | HDF5 format |
| Contains | Weights & training state | Model architecture & weights | Model architecture & weights |
| Size | Medium | Larger | Medium |
| Primary Use | Resuming training | Deployment | Model sharing |
| Compatibility | TensorFlow only | TensorFlow, TF Serving, etc. | Cross-framework |
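For reference, all three formats can be produced from the same Keras model; in TF 2.x the format is chosen by the path you pass (a sketch with illustrative paths):

```python
model.save_weights('ckpts/my_model')   # TF checkpoint files (weights only)
model.save('exported_model')           # SavedModel directory (architecture + weights)
model.save('my_model.h5')              # single HDF5 file (architecture + weights)
```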
## Summary
TensorFlow Checkpoints are a powerful mechanism for saving training progress and ensuring your work is protected from unexpected interruptions. In this tutorial, we covered:
- Basic checkpoint creation and loading
- Using the `ModelCheckpoint` callback
- Managing checkpoints with `CheckpointManager`
- Practical examples of resuming interrupted training
- Best practices for using checkpoints effectively
By implementing checkpoints in your machine learning workflows, you'll ensure that long training sessions can be interrupted and resumed without losing progress, making your development process more resilient and efficient.
## Additional Resources and Exercises
### Practice Exercises
- Checkpoint Frequency: Modify the basic checkpoint example to save a checkpoint every 100 training steps rather than every epoch.
- Selective Loading: Create a model where you load weights from a checkpoint for only specific layers while initializing others randomly.
- Checkpoint Visualization: Write a script that loads multiple checkpoints and visualizes how model performance changes across checkpoints.
- Custom Metrics: Create a checkpoint system that saves models based on a custom metric rather than the default accuracy or loss.
- Cloud Storage: Extend the checkpoint examples to save and load checkpoints from cloud storage (like Google Cloud Storage or AWS S3).