
TensorFlow Early Stopping

Introduction

When training deep learning models, one of the common challenges is determining when to stop training. Training for too few epochs may result in underfitting, while training for too many epochs can lead to overfitting, where the model performs well on training data but poorly on unseen data.

Early stopping is a regularization technique that addresses this challenge by monitoring a model's performance on a validation dataset and stopping training once that performance stops improving. This helps prevent overfitting and saves computational resources by avoiding unnecessary training epochs.

In this tutorial, we'll learn:

  • What early stopping is and why it's important
  • How to implement early stopping in TensorFlow
  • How to customize early stopping parameters
  • How to visualize the effects of early stopping

Understanding Early Stopping

Early stopping works by tracking a specified metric (usually validation loss or accuracy) during training. When this metric stops improving for a specified number of epochs, training is halted. This simple yet effective technique is based on the observation that validation performance typically improves initially but eventually plateaus or degrades as the model begins to overfit.
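The stopping rule is simple enough to sketch in plain Python. The helper below, `epochs_until_stop`, is hypothetical and is not TensorFlow's implementation. It assumes a metric that should decrease, ignores `min_delta` and the other options covered later, and numbers epochs from 1. It only shows how a patience counter reacts to a sequence of validation losses.

```python
def epochs_until_stop(val_losses, patience):
    """Hypothetical helper, not part of TensorFlow.

    Simplified early-stopping rule for a metric that should decrease.
    Returns (best_epoch, stop_epoch), both numbered from 1. stop_epoch is
    None if the rule never triggers within the given losses.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None


# Validation loss improves for four epochs, then rises (overfitting).
losses = [0.90, 0.70, 0.60, 0.55, 0.56, 0.58, 0.61, 0.65, 0.70]
print(epochs_until_stop(losses, patience=3))  # (4, 7)
```

Training halts `patience` epochs after the best epoch. This is why the final weights are usually not the best ones.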

Figure: Early stopping visualization

Implementing Early Stopping in TensorFlow

TensorFlow provides the EarlyStopping callback class as part of its keras.callbacks module. Let's look at how to implement early stopping in a simple neural network.

Basic Implementation

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

# Create a simple synthetic dataset. The labels are random, so validation loss
# stays close to chance level (about 0.693 for binary cross-entropy).
X_train = np.random.random((1000, 20))
y_train = np.random.randint(0, 2, (1000, 1))
X_val = np.random.random((300, 20))
y_val = np.random.randint(0, 2, (300, 1))

# Create a simple binary classifier
model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Create an early stopping callback
early_stopping = EarlyStopping(
    monitor='val_loss',         # Metric to monitor
    patience=5,                 # Epochs with no improvement after which training stops
    restore_best_weights=True   # Restore weights from the epoch with the best monitored value
)

# Train the model with early stopping
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=100,                 # Maximum number of epochs
    callbacks=[early_stopping],
    verbose=1
)

# Check how many epochs were actually run
print(f"Training stopped after {len(history.history['loss'])} epochs")
```

The output might look something like this. Exact numbers vary between runs because the data is random:

Epoch 1/100
32/32 [==============================] - 0s 3ms/step - loss: 0.6931 - accuracy: 0.4990 - val_loss: 0.6931 - val_accuracy: 0.5033
Epoch 2/100
32/32 [==============================] - 0s 3ms/step - loss: 0.6930 - accuracy: 0.5090 - val_loss: 0.6930 - val_accuracy: 0.5067
...
Epoch 22/100
32/32 [==============================] - 0s 3ms/step - loss: 0.6895 - accuracy: 0.5410 - val_loss: 0.6926 - val_accuracy: 0.5133
Epoch 23/100
32/32 [==============================] - 0s 3ms/step - loss: 0.6893 - accuracy: 0.5440 - val_loss: 0.6930 - val_accuracy: 0.5100
Training stopped after 23 epochs
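When early stopping triggers with `restore_best_weights=True`, the model is rolled back to its best epoch. This means the last epoch in the log is not the one whose weights you keep. The sketch below recovers both numbers from a history dictionary such as `history.history`. The helper name `summarize_early_stopping` is hypothetical. It assumes a metric that should decrease and the default `min_delta=0`, and the example values are illustrative, not real training output.

```python
def summarize_early_stopping(history_dict, monitor="val_loss"):
    """Return (epochs_run, best_epoch) from a Keras-style history dict.

    Hypothetical helper. Assumes `monitor` should decrease and that
    history_dict maps metric names to per-epoch lists, like
    model.fit(...).history. Epochs are numbered from 1.
    """
    values = history_dict[monitor]
    epochs_run = len(values)
    # min() returns the first index with the lowest value, which matches a
    # strict "less than" improvement test with min_delta=0.
    best_index = min(range(epochs_run), key=lambda i: values[i])
    return epochs_run, best_index + 1


# With the real run: epochs_run, best_epoch = summarize_early_stopping(history.history)
example_history = {"val_loss": [0.693, 0.692, 0.690, 0.691, 0.692, 0.693, 0.694, 0.695]}
print(summarize_early_stopping(example_history))  # (8, 3)
```

With `patience=5`, a run that stops at epoch 8 had its best epoch at 3. Those are the weights `restore_best_weights=True` brings back.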

Customizing Early Stopping Parameters

The EarlyStopping callback provides several parameters for customization:

```python
early_stopping = EarlyStopping(
    monitor='val_loss',         # Metric to monitor ('val_loss', 'val_accuracy', etc.)
    min_delta=0.001,            # Changes smaller than this do not count as an improvement
    patience=10,                # Epochs with no improvement after which training stops
    verbose=1,                  # 1 prints a message when the callback stops training
    mode='min',                 # 'min' if the metric should decrease, 'max' if it should increase
    baseline=None,              # Optional threshold the metric must beat, otherwise training stops after `patience` epochs
    restore_best_weights=True,  # Restore weights from the epoch with the best monitored value
    start_from_epoch=0          # Number of initial (warm-up) epochs to ignore before monitoring starts
)
```
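To make `min_delta` and `mode` concrete, here is a sketch of the comparison they control. It follows the documented behavior, where a change smaller than `min_delta` does not count as an improvement. The function `is_improvement` is a hypothetical stand-in, not the Keras source.

```python
def is_improvement(current, best, min_delta=0.0, mode="min"):
    """Hypothetical helper: True if `current` beats `best` by more than min_delta.

    mode="min": the metric should decrease (e.g. val_loss).
    mode="max": the metric should increase (e.g. val_accuracy).
    """
    delta = abs(min_delta)
    if mode == "min":
        return current < best - delta
    if mode == "max":
        return current > best + delta
    raise ValueError("mode must be 'min' or 'max'")


# With min_delta=0.001, a drop from 0.5000 to 0.4995 is too small to count...
print(is_improvement(0.4995, 0.5000, min_delta=0.001))  # False
# ...but a drop to 0.4980 does.
print(is_improvement(0.4980, 0.5000, min_delta=0.001))  # True
# For accuracy, higher is better.
print(is_improvement(0.91, 0.90, min_delta=0.001, mode="max"))  # True
```

With the default `mode='auto'`, Keras infers the direction from the metric name. Setting `mode` explicitly is a good habit for custom metrics.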

Common Configurations

Let's examine three common early stopping configurations:

1. Monitoring Validation Accuracy

```python
# For classification tasks where we want to maximize accuracy
early_stopping_accuracy = EarlyStopping(
    monitor='val_accuracy',
    mode='max',  # We want accuracy to be maximized
    patience=7,
    verbose=1,
    restore_best_weights=True
)
```

2. Stricter Early Stopping

```python
# For when you want to stop soon after validation loss stops improving
strict_early_stopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0.0001,  # Small threshold: almost any decrease counts as an improvement
    patience=3,        # Stop after just 3 epochs without improvement
    verbose=1,
    restore_best_weights=True
)
```

3. Early Stopping with Baseline

```python
# When you have a minimum performance requirement
baseline_early_stopping = EarlyStopping(
    monitor='val_loss',
    baseline=0.2,  # Stop after `patience` epochs unless val_loss drops below 0.2
    patience=5,
    verbose=1,
    restore_best_weights=True
)
```
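To compare the three configurations above without training a model, the sketch below replays an illustrative (made-up) validation-loss curve through a simplified stopping rule. `simulate_early_stopping` is hypothetical. It only handles metrics that should decrease, ignores `restore_best_weights` and `start_from_epoch`, and treats `baseline` as the initial best value. That matches the documented intent (training stops if the metric does not improve on the baseline), but exact edge-case handling differs between TensorFlow/Keras versions.

```python
import math


def simulate_early_stopping(values, patience, min_delta=0.0, baseline=None):
    """Hypothetical, simplified early stopping for a metric that should decrease.

    Returns the epoch (numbered from 1) at which training would stop, or None
    if it would run through all given epochs. A baseline acts as the initial
    "best" value, so only values below it reset the patience counter.
    """
    best = math.inf if baseline is None else baseline
    wait = 0
    for epoch, value in enumerate(values, start=1):
        if value < best - abs(min_delta):
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Illustrative curve: improves until epoch 5, then slowly gets worse.
val_loss = [0.50, 0.40, 0.35, 0.33, 0.329, 0.3295, 0.332, 0.336, 0.34, 0.35, 0.36, 0.37]

print(simulate_early_stopping(val_loss, patience=5))                    # 10
print(simulate_early_stopping(val_loss, patience=3, min_delta=0.0001))  # 8
print(simulate_early_stopping(val_loss, patience=5, baseline=0.2))      # 5
```

With `baseline=0.2`, the loss never gets below the threshold. Training therefore stops after exactly `patience` epochs, even though the loss was still improving at epoch 5.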

Practical Example: MNIST Classification

Let's implement early stopping in a more practical example using the MNIST dataset:

```python
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.models import Sequential

# Load the MNIST dataset and scale pixel values to [0, 1]
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train / 255.0
X_test = X_test / 255.0

# Define the model
model = Sequential([
    Input(shape=(28, 28)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
    Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # Labels are integers 0-9
    metrics=['accuracy']
)

# Create early stopping callback
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    patience=5,
    verbose=1,
    restore_best_weights=True
)

# Train the model with early stopping (the last 20% of the training data is held out for validation)
history = model.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=50,
    batch_size=128,
    callbacks=[early_stopping],
    verbose=1
)

# Evaluate the model
test_loss, test_acc = model.evaluate(X_test, y_test)
print(f"Test accuracy: {test_acc:.4f}")

# Visualize training history
def plot_history(history):
    plt.figure(figsize=(12, 4))

    # Plot training & validation accuracy
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='lower right')

    # Plot training & validation loss
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper right')

    plt.tight_layout()
    plt.show()

plot_history(history)
```

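The plots show how the training and validation metrics evolved. If early stopping triggers, the curves end `patience` epochs after the best validation accuracy. Because `restore_best_weights=True`, the model evaluated on the test set uses the weights from that best epoch, not from the last epoch plotted.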

Combining Early Stopping with Other Callbacks

Early stopping works well when combined with other callbacks like ModelCheckpoint and ReduceLROnPlateau:

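```python
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

# Reuses `model`, `early_stopping`, X_train and y_train from the MNIST example above.
# Calling fit() again continues from the current weights; rebuild and recompile
# the model for a fresh run.

# Save the best model during training
model_checkpoint = ModelCheckpoint(
    filepath='best_model.keras',  # .keras is the native format in recent Keras versions
    monitor='val_loss',
    save_best_only=True,
    verbose=1
)

# Reduce the learning rate when validation loss has stopped improving
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,     # New learning rate = old learning rate * 0.2
    patience=3,     # Shorter than early_stopping's patience, so the rate can drop before training stops
    min_lr=0.0001,  # Lower bound on the learning rate
    verbose=1
)

# Train with multiple callbacks
history = model.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=50,
    batch_size=128,
    callbacks=[early_stopping, model_checkpoint, reduce_lr],
    verbose=1
)
```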
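A quick calculation shows how `factor` and `min_lr` interact. Assume Adam's default learning rate of 0.001 (the `optimizer='adam'` string uses the defaults). Each reduction multiplies the rate by 0.2, and the result is clipped at `min_lr`. The helper `lr_after_reductions` is hypothetical and only reproduces that arithmetic.

```python
def lr_after_reductions(initial_lr, factor, min_lr, reductions):
    """Hypothetical helper: learning rates after each ReduceLROnPlateau-style cut.

    Each reduction multiplies the rate by `factor` but never goes below `min_lr`.
    """
    rates = []
    lr = initial_lr
    for _ in range(reductions):
        lr = max(lr * factor, min_lr)
        rates.append(lr)
    return rates


rates = lr_after_reductions(initial_lr=0.001, factor=0.2, min_lr=0.0001, reductions=3)
print([round(r, 6) for r in rates])  # [0.0002, 0.0001, 0.0001]
```

The learning rate reaches its floor after two reductions. A common choice is to give ReduceLROnPlateau a shorter patience than EarlyStopping (3 vs. 5 here), so the learning rate can be lowered before training is stopped.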

When to Use Early Stopping

Early stopping is most useful when:

  1. Training large models: Prevents wasting computational resources
  2. Limited training data: Helps prevent overfitting
  3. Hyperparameter tuning: Allows different model configurations to train only as long as necessary
  4. Production environments: Automatically determines optimal training duration

However, be careful with very small patience values. Validation metrics fluctuate from epoch to epoch, so a small patience can stop training on a temporary bump, before the model has had a chance to learn effectively.
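To see the effect, the sketch below runs a bare-bones patience rule over an illustrative noisy validation-loss curve that has a one-epoch bump. `stop_epoch` is a hypothetical helper that assumes a metric that should decrease and no `min_delta`.

```python
def stop_epoch(val_losses, patience):
    """Hypothetical helper: epoch (from 1) at which a simple patience rule stops, or None."""
    best, wait = float("inf"), 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Illustrative curve: a one-epoch bump at epoch 4, then further improvement.
noisy = [0.80, 0.60, 0.50, 0.52, 0.45, 0.40, 0.38, 0.37, 0.37, 0.38]
print(stop_epoch(noisy, patience=1))  # 4
print(stop_epoch(noisy, patience=5))  # None
```

With `patience=1`, training stops at the bump and keeps a best loss of 0.50 from epoch 3. With `patience=5`, training continues past the bump and reaches 0.37.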

Summary

Early stopping is a powerful regularization technique in deep learning. It helps prevent overfitting by monitoring validation performance and stopping training once that performance stops improving. Key points to remember:

  • Early stopping works by monitoring a specific metric (usually validation loss or accuracy)
  • The patience parameter defines how many epochs without improvement to wait before stopping
  • Setting restore_best_weights=True rolls the model back to the weights from the best epoch when early stopping halts training
  • Early stopping can be combined with other callbacks for more effective training

By implementing early stopping in your TensorFlow models, you can:

  • Save computational resources
  • Prevent overfitting
  • Automatically determine the optimal training duration
  • Maintain the best model weights


Exercises

  1. Modify the MNIST example to compare models trained with and without early stopping. How does the test accuracy differ?
  2. Experiment with different patience values (1, 5, 10, 15) and observe how they affect the final model performance.
  3. Implement early stopping that monitors both validation loss and validation accuracy using multiple callbacks.
  4. Create a custom callback that builds on early stopping to also log when the training would have stopped for different patience values.

