TensorFlow Mirrored Strategy
Introduction
TensorFlow's MirroredStrategy is one of the most commonly used distribution strategies for training deep learning models across multiple GPUs within a single machine. As models grow in complexity and datasets increase in size, utilizing multiple GPUs becomes crucial for reducing training time and improving efficiency.
In this tutorial, we'll explore how MirroredStrategy works, when to use it, and how to implement it in your TensorFlow projects. By the end, you'll be able to seamlessly scale your model training across multiple GPUs on a single machine.
What is MirroredStrategy?
MirroredStrategy is a distribution strategy that creates a copy (or "mirror") of your model on each available GPU. During training:
- Each GPU processes a different slice of input data in parallel
- Gradients from each GPU are combined using an all-reduce algorithm
- Model weights are updated synchronously across all devices
This approach is called synchronous training because all devices work together on the same batch of data and wait for each other before proceeding to the next batch.
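The all-reduce step is configurable. When you construct the strategy (as we do in the next section), you can optionally pass a specific cross-device communication implementation; by default TensorFlow chooses one for you (NCCL on NVIDIA GPUs). A small sketch of that option:
# Optional: choose the all-reduce implementation explicitly
strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()
)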
Setting Up MirroredStrategy
Let's start by importing the necessary libraries and setting up the MirroredStrategy:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, models
import datetime
# Check for available GPUs
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
# Create a MirroredStrategy
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
Output:
Num GPUs Available: 2
Number of devices: 2
The output will vary based on your hardware setup. In this example, we have 2 GPUs available for training.
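If you want to mirror the model across only a subset of the visible GPUs, you can list the devices explicitly when creating the strategy (the device names below are just examples):
# Restrict the strategy to specific GPUs
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])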
Creating a Model with MirroredStrategy
When using a distribution strategy, you need to create your model within the strategy's scope. This ensures that the model's variables are properly mirrored across devices.
Let's create a simple CNN model for MNIST digit classification:
# Define batch size and other parameters
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
# Load and preprocess data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# Add channel dimension
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]
# Create tf.data datasets
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
# Create the model within the strategy scope
with strategy.scope():
    model = tf.keras.Sequential([
        layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10)
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )

# Display model summary
model.summary()
Training the Model
Now, let's train the model across multiple GPUs:
# Set up TensorBoard logs
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# Train the model
history = model.fit(
    train_dataset,
    epochs=5,
    validation_data=test_dataset,
    callbacks=[tensorboard_callback]
)
Output:
Epoch 1/5
938/938 [==============================] - 8s 9ms/step - loss: 0.2076 - accuracy: 0.9364 - val_loss: 0.0770 - val_accuracy: 0.9763
Epoch 2/5
938/938 [==============================] - 7s 8ms/step - loss: 0.0671 - accuracy: 0.9790 - val_loss: 0.0624 - val_accuracy: 0.9801
Epoch 3/5
938/938 [==============================] - 7s 8ms/step - loss: 0.0498 - accuracy: 0.9844 - val_loss: 0.0541 - val_accuracy: 0.9829
Epoch 4/5
938/938 [==============================] - 7s 8ms/step - loss: 0.0388 - accuracy: 0.9880 - val_loss: 0.0473 - val_accuracy: 0.9844
Epoch 5/5
938/938 [==============================] - 7s 8ms/step - loss: 0.0318 - accuracy: 0.9902 - val_loss: 0.0513 - val_accuracy: 0.9850
Because each step's workload is split across both GPUs, each epoch completes noticeably faster than it would on a single device; we quantify the speedup in the next section.
Performance Comparison
Let's compare the training time between single-GPU and multi-GPU setups:
import time
# Function to measure training time
def measure_training_time(strategy_name, strategy=None):
    start_time = time.time()

    if strategy:
        with strategy.scope():
            model = tf.keras.Sequential([
                layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
                layers.MaxPooling2D(),
                layers.Conv2D(64, 3, activation='relu'),
                layers.MaxPooling2D(),
                layers.Flatten(),
                layers.Dense(128, activation='relu'),
                layers.Dense(10)
            ])
            model.compile(
                optimizer=tf.keras.optimizers.Adam(),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
            )
    else:
        # Single GPU model
        model = tf.keras.Sequential([
            layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, activation='relu'),
            layers.MaxPooling2D(),
            layers.Flatten(),
            layers.Dense(128, activation='relu'),
            layers.Dense(10)
        ])
        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )

    # Create datasets (with appropriate batch sizes)
    if strategy:
        batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
    else:
        batch_size = BATCH_SIZE_PER_REPLICA
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size)

    # Train for 3 epochs
    model.fit(train_dataset, epochs=3, verbose=0)

    end_time = time.time()
    print(f"{strategy_name} training time: {end_time - start_time:.2f} seconds")

# Measure time for single GPU
with tf.device('/gpu:0'):
    measure_training_time("Single GPU", None)

# Measure time for MirroredStrategy
measure_training_time("MirroredStrategy", strategy)
Output:
Single GPU training time: 21.54 seconds
MirroredStrategy training time: 11.87 seconds
The results show that MirroredStrategy significantly reduces training time by leveraging multiple GPUs.
Advanced Usage: Custom Training Loops
For more control over the training process, you can implement custom training loops with MirroredStrategy. This is useful when you need to implement complex training procedures or custom metrics.
with strategy.scope():
    # Define the model
    inputs = tf.keras.layers.Input(shape=(28, 28, 1))
    x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Conv2D(64, 3, activation='relu')(x)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    outputs = tf.keras.layers.Dense(10)(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    # Define optimizer and loss function.
    # In a custom loop, compute the per-example loss and average it over the
    # *global* batch size so the gradients are scaled correctly when the
    # per-replica losses are summed across devices.
    optimizer = tf.keras.optimizers.Adam()
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def loss_fn(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=BATCH_SIZE)

    # Define metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

# Create distributed datasets
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

# Define the training step (run on each replica)
def train_step(inputs):
    images, labels = inputs
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_accuracy.update_state(labels, predictions)
    return loss

# Define the distributed training step
@tf.function
def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
# Custom training loop
epochs = 3
for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0
    train_accuracy.reset_states()
    for x in train_dist_dataset:
        total_loss += distributed_train_step(x)
        num_batches += 1
    train_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}, Loss: {train_loss.numpy():.4f}, Accuracy: {train_accuracy.result().numpy():.4f}")
Output:
Epoch 1, Loss: 0.2033, Accuracy: 0.9386
Epoch 2, Loss: 0.0648, Accuracy: 0.9798
Epoch 3, Loss: 0.0461, Accuracy: 0.9855
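The same pattern extends to evaluation. Below is a sketch of a distributed test loop that reuses the model and test_dataset defined earlier; the accuracy metric is created inside the strategy scope because it owns variables:
with strategy.scope():
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

def test_step(inputs):
    images, labels = inputs
    predictions = model(images, training=False)
    test_accuracy.update_state(labels, predictions)

@tf.function
def distributed_test_step(dataset_inputs):
    strategy.run(test_step, args=(dataset_inputs,))

test_accuracy.reset_states()
for x in test_dist_dataset:
    distributed_test_step(x)
print(f"Test accuracy: {test_accuracy.result().numpy():.4f}")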
Real-World Application: Training a ResNet on CIFAR-10
Let's implement a more realistic example by training a ResNet model on the CIFAR-10 dataset:
# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# Create tf.data datasets
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
# Define a ResNet block
def resnet_block(x, filters, kernel_size=3, stride=1, conv_shortcut=False):
    shortcut = x
    if conv_shortcut:
        shortcut = layers.Conv2D(filters, 1, strides=stride)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)
    x = layers.Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, kernel_size, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.add([x, shortcut])
    x = layers.ReLU()(x)
    return x
# Create ResNet model with strategy
with strategy.scope():
    inputs = layers.Input(shape=(32, 32, 3))
    # Initial convolution
    x = layers.Conv2D(64, 7, strides=2, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.MaxPooling2D(3, strides=2, padding='same')(x)
    # ResNet blocks
    x = resnet_block(x, 64, conv_shortcut=True)
    x = resnet_block(x, 64)
    x = resnet_block(x, 128, stride=2, conv_shortcut=True)
    x = resnet_block(x, 128)
    x = resnet_block(x, 256, stride=2, conv_shortcut=True)
    x = resnet_block(x, 256)
    # Final layers
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(10)(x)
    resnet_model = tf.keras.Model(inputs, outputs)

    # Compile model
    resnet_model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )

# Train the model
resnet_history = resnet_model.fit(
    train_dataset,
    epochs=10,
    validation_data=test_dataset
)
This example demonstrates how to use MirroredStrategy with a more complex model (ResNet) on a real-world dataset (CIFAR-10). Multi-GPU training significantly reduces the training time for this deeper model.
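Once training finishes, saving and reloading work just as in the single-device case: the saved model contains a single copy of the (synchronized) weights. A short sketch, using a hypothetical path:
# Save in the TensorFlow SavedModel format (path is just an example)
resnet_model.save("resnet_cifar10_savedmodel")

# Reloading outside the scope yields a regular single-device model;
# reload inside strategy.scope() to continue distributed training.
restored_model = tf.keras.models.load_model("resnet_cifar10_savedmodel")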
Best Practices for MirroredStrategy
When using MirroredStrategy, consider the following best practices:
- Batch Size: Increase your batch size proportionally to the number of GPUs. For example, if your single-GPU batch size is 32 and you're using 4 GPUs, set your total batch size to 128.
- Learning Rate: You may need to adjust your learning rate when scaling up batch sizes. One common heuristic is the "square root scaling rule": multiply your learning rate by the square root of the batch-size increase.
# Example of learning rate scaling
single_gpu_batch_size = 32
multi_gpu_batch_size = 32 * strategy.num_replicas_in_sync
single_gpu_learning_rate = 0.001

# Square root scaling rule
multi_gpu_learning_rate = single_gpu_learning_rate * (multi_gpu_batch_size / single_gpu_batch_size) ** 0.5

with strategy.scope():
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=multi_gpu_learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
- Memory Usage: Be aware of memory limitations on your GPUs. If your model is large, you might need to adjust the batch size or model architecture to fit within the available GPU memory.
- Dataset Preparation: Use the tf.data API for efficient data loading and preprocessing to avoid becoming I/O bound.
# Optimize dataset for performance
train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .shuffle(10000)
                 .batch(BATCH_SIZE)
                 .prefetch(tf.data.experimental.AUTOTUNE))
- Variable Creation: Always create variables (models, optimizers, metrics, etc.) within the strategy's scope, as shown in the sketch after this list.
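For that last point, here is a minimal sketch of the pattern, reusing the imports from earlier in the tutorial: anything that creates tf.Variables (the model, the optimizer, metrics) is built inside strategy.scope(), while tf.data pipelines and callbacks can live outside it.
# Variable-creating objects go inside the scope
with strategy.scope():
    model = tf.keras.Sequential([
        layers.Flatten(input_shape=(28, 28, 1)),
        layers.Dense(10)
    ])
    optimizer = tf.keras.optimizers.Adam()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

# Dataset creation does not create variables, so it can stay outside the scope
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(BATCH_SIZE)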
Common Issues and Solutions
1. Out of Memory (OOM) Errors
If you encounter OOM errors, try:
- Reducing batch size per GPU
- Using mixed precision training
- Simplifying your model architecture
# Enable mixed precision training
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
with strategy.scope():
    # Model definition here
    # ...
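With the mixed_float16 policy active, the usual Keras guidance is to keep the model's final outputs in float32 for numerical stability. A minimal sketch under the same MNIST setup as above:
with strategy.scope():
    model = tf.keras.Sequential([
        layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        # Compute the logits in float32 even though earlier layers run in float16
        layers.Dense(10, dtype='float32')
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
In recent TensorFlow releases, model.fit applies loss scaling automatically under this policy; custom training loops need an explicit tf.keras.mixed_precision.LossScaleOptimizer.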
2. Training Slower Than Expected
If the speedup is less than expected:
- Check if your training is I/O bound (slow data loading)
- Ensure your model is complex enough to benefit from distribution
- Check for uneven GPU utilization
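One way to check for an input-bound pipeline is to profile a few batches with the TensorBoard callback and then inspect GPU utilization and the input-pipeline analysis in TensorBoard. A rough sketch (the log directory and batch range are arbitrary):
# Profile batches 10-20 of the first epoch
profiler_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/profile",
    profile_batch=(10, 20)
)
model.fit(train_dataset, epochs=1, callbacks=[profiler_callback])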
3. Unexpected Results
If you get different results compared to single-GPU training:
- Ensure random seeds are set consistently
- Check if batch normalization is applied correctly across devices
- Verify your learning rate adjustments
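For the first two points, a small sketch of what this can look like; note that tf.keras.utils.set_random_seed is only available in newer TensorFlow releases (on older versions, call tf.random.set_seed and np.random.seed separately), and synchronized batch normalization is an experimental layer:
# Seed Python, NumPy and TensorFlow RNGs in one call (seed value is arbitrary)
tf.keras.utils.set_random_seed(42)

# Batch normalization statistics are computed per replica by default; the
# synchronized variant aggregates them across all replicas
sync_bn = tf.keras.layers.experimental.SyncBatchNormalization()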
Summary
In this tutorial, we've learned:
- How MirroredStrategy enables efficient multi-GPU training on a single machine
- How to set up and implement MirroredStrategy for both simple and complex models
- How to create custom training loops with MirroredStrategy
- Best practices for optimizing multi-GPU training
- Common issues and their solutions
TensorFlow's MirroredStrategy provides an easy way to utilize multiple GPUs on a single machine, significantly reducing training time for large models and datasets. By following the best practices outlined in this tutorial, you can efficiently scale up your deep learning training workflows.
Exercises
- Modify the MNIST example to use mixed precision training with MirroredStrategy.
- Implement a custom callback that reports the training time per epoch when using MirroredStrategy.
- Experiment with different batch sizes and learning rates to find the optimal configuration for your hardware.
- Implement a more complex model (like EfficientNet or BERT) with MirroredStrategy.
- Compare the performance of MirroredStrategy with other distribution strategies like ParameterServerStrategy or TPUStrategy.