TensorFlow Mirrored Strategy

Introduction

TensorFlow's MirroredStrategy is one of the most commonly used distribution strategies for training deep learning models across multiple GPUs within a single machine. As models grow in complexity and datasets increase in size, utilizing multiple GPUs becomes crucial for reducing training time and improving efficiency.

In this tutorial, we'll explore how MirroredStrategy works, when to use it, and how to implement it in your TensorFlow projects. By the end, you'll be able to seamlessly scale your model training across multiple GPUs on a single machine.

What is MirroredStrategy?

MirroredStrategy is a distribution strategy that creates a copy (or "mirror") of your model on each available GPU. During training:

  1. Each GPU processes a different slice of input data in parallel
  2. Gradients from each GPU are combined using an all-reduce algorithm
  3. Model weights are updated synchronously across all devices

This approach is called synchronous training: every device processes its own slice of the same global batch, and all devices wait for the gradient exchange to finish before moving on to the next batch.
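
The three steps above can be illustrated without any GPU. The following is a minimal pure-Python sketch, not TensorFlow's implementation: the toy model `y = w * x`, the helper names `grad_sum` and `mirrored_step`, and the data are all made up for illustration, and the all-reduce is modeled as a plain sum.

```python
# Toy model: y = w * x with squared-error loss and a single scalar weight.
def grad_sum(w, examples):
    """Sum of per-example gradients d/dw (w*x - y)^2 over a slice of data."""
    return sum(2.0 * (w * x - y) * x for x, y in examples)

def mirrored_step(w, global_batch, num_replicas, lr):
    # 1. Each replica receives a different slice of the global batch.
    slices = [global_batch[i::num_replicas] for i in range(num_replicas)]
    # 2. Each replica computes gradients on its own slice, then an
    #    all-reduce (modeled here as a sum) combines them.
    per_replica = [grad_sum(w, s) for s in slices]
    reduced = sum(per_replica) / len(global_batch)
    # 3. Every replica applies the same update, so the copies stay identical.
    return [w - lr * reduced for _ in range(num_replicas)]

batch = [(1.0, 2.0), (2.0, 4.5), (3.0, 5.5), (4.0, 8.0)]
replica_weights = mirrored_step(w=0.0, global_batch=batch, num_replicas=2, lr=0.01)

# Reference: one device processing the whole batch.
single_device = 0.0 - 0.01 * (grad_sum(0.0, batch) / len(batch))
print(replica_weights, single_device)
```

Both replicas end up with the same weight, and that weight matches what a single device would compute on the full batch. This is why synchronous training behaves like ordinary training with a larger batch.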

Setting Up MirroredStrategy

Let's start by importing the necessary libraries and setting up the MirroredStrategy:

python
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, models
import datetime

# Check for available GPUs
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Create a MirroredStrategy
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

Output:

Num GPUs Available:  2
Number of devices: 2

The output will vary based on your hardware setup. In this example, we have 2 GPUs available for training.
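
If the same script also has to run on machines with fewer than two GPUs, a small helper can choose the strategy based on the hardware it finds. The sketch below is an assumption rather than part of the setup above: `pick_strategy` is a hypothetical helper name, and falling back to TensorFlow's default (single-device) strategy is only one reasonable choice.

```python
import tensorflow as tf

def pick_strategy(max_gpus=None):
    """Return a MirroredStrategy over the visible GPUs, or the default strategy."""
    gpus = tf.config.list_physical_devices("GPU")
    if max_gpus is not None:
        gpus = gpus[:max_gpus]
    if len(gpus) >= 2:
        # Explicit device list, e.g. ["/gpu:0", "/gpu:1"]
        devices = [f"/gpu:{i}" for i in range(len(gpus))]
        return tf.distribute.MirroredStrategy(devices=devices)
    # Zero or one GPU: the default strategy runs everything on one device.
    return tf.distribute.get_strategy()

selected_strategy = pick_strategy()
print(type(selected_strategy).__name__, selected_strategy.num_replicas_in_sync)
```

Code written against `selected_strategy.scope()` and `selected_strategy.num_replicas_in_sync` then works unchanged on either kind of machine.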

Creating a Model with MirroredStrategy

When using a distribution strategy, you need to create your model within the strategy's scope. This ensures that the model's variables are properly mirrored across devices.

Let's create a simple CNN model for MNIST digit classification:

python
# Define batch size and other parameters
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Load and preprocess data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add channel dimension
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

# Create tf.data datasets
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

# Create the model within the strategy scope
with strategy.scope():
    model = tf.keras.Sequential([
        layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10)
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )

# Display model summary
model.summary()

Training the Model

Now, let's train the model across multiple GPUs:

python
# Set up TensorBoard logs
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# Train the model
history = model.fit(
    train_dataset,
    epochs=5,
    validation_data=test_dataset,
    callbacks=[tensorboard_callback]
)

Output:

Epoch 1/5
938/938 [==============================] - 8s 9ms/step - loss: 0.2076 - accuracy: 0.9364 - val_loss: 0.0770 - val_accuracy: 0.9763
Epoch 2/5
938/938 [==============================] - 7s 8ms/step - loss: 0.0671 - accuracy: 0.9790 - val_loss: 0.0624 - val_accuracy: 0.9801
Epoch 3/5
938/938 [==============================] - 7s 8ms/step - loss: 0.0498 - accuracy: 0.9844 - val_loss: 0.0541 - val_accuracy: 0.9829
Epoch 4/5
938/938 [==============================] - 7s 8ms/step - loss: 0.0388 - accuracy: 0.9880 - val_loss: 0.0473 - val_accuracy: 0.9844
Epoch 5/5
938/938 [==============================] - 7s 8ms/step - loss: 0.0318 - accuracy: 0.9902 - val_loss: 0.0513 - val_accuracy: 0.9850

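Each training step splits one global batch across the available GPUs, so the wall-clock time per epoch depends on both the number of replicas and the global batch size. The next section measures the difference against a single GPU.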
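
The step count in the progress bar follows directly from the global batch size. As a rough sketch (the helper name `steps_per_epoch` is made up here; 60,000 is the size of the MNIST training set):

```python
import math

def steps_per_epoch(num_examples, per_replica_batch, num_replicas):
    """Number of optimizer steps needed to see every example once."""
    global_batch = per_replica_batch * num_replicas
    return math.ceil(num_examples / global_batch)

for replicas in (1, 2, 4):
    print(replicas, "replica(s):", steps_per_epoch(60_000, 64, replicas), "steps")
```

With a per-replica batch of 64, one replica needs 938 steps per epoch, two need 469, and four need 235. The 938 steps in the log above therefore correspond to a global batch of 64; if your run uses two replicas with the batch size defined earlier, expect to see 469 instead.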

Performance Comparison

Let's compare the training time between single-GPU and multi-GPU setups:

python
import time

def build_model():
    """Build and compile the same CNN used above."""
    model = tf.keras.Sequential([
        layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10)
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    return model

# Function to measure training time
def measure_training_time(strategy_name, strategy=None):
    start_time = time.time()

    if strategy:
        # Variables must be created inside the strategy's scope
        with strategy.scope():
            model = build_model()
        batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
    else:
        # Single GPU model
        model = build_model()
        batch_size = BATCH_SIZE_PER_REPLICA

    # Create the dataset with the appropriate batch size
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size)

    # Train for 3 epochs
    model.fit(train_dataset, epochs=3, verbose=0)

    end_time = time.time()
    print(f"{strategy_name} training time: {end_time - start_time:.2f} seconds")

# Measure time for single GPU
with tf.device('/gpu:0'):
    measure_training_time("Single GPU", None)

# Measure time for MirroredStrategy
measure_training_time("MirroredStrategy", strategy)

Output:

Single GPU training time: 21.54 seconds
MirroredStrategy training time: 11.87 seconds

In this run, MirroredStrategy on two GPUs cut the training time from 21.54 seconds to 11.87 seconds. The exact numbers will differ on other hardware.
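
To judge how well training scales, it helps to turn the two timings into a speedup and a scaling efficiency (speedup divided by the number of devices). The helper name `scaling_report` below is made up for illustration; the timings are the ones reported above.

```python
def scaling_report(single_time, multi_time, num_devices):
    """Return (speedup, efficiency) for a multi-device run."""
    speedup = single_time / multi_time
    efficiency = speedup / num_devices  # 1.0 would be perfect linear scaling
    return speedup, efficiency

speedup, efficiency = scaling_report(21.54, 11.87, num_devices=2)
print(f"speedup: {speedup:.2f}x, efficiency: {efficiency:.0%}")
```

An efficiency well below 100% usually points to overhead from gradient synchronization or to an input pipeline that cannot keep the GPUs busy.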

Advanced Usage: Custom Training Loops

For more control over the training process, you can implement custom training loops with MirroredStrategy. This is useful when you need to implement complex training procedures or custom metrics.

python
with strategy.scope():
    # Define the model
    inputs = tf.keras.layers.Input(shape=(28, 28, 1))
    x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Conv2D(64, 3, activation='relu')(x)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    outputs = tf.keras.layers.Dense(10)(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    # Define optimizer and loss function.
    # Reduction is disabled so the loss can be averaged over the global batch.
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    # Define metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

# Create distributed datasets
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

# Define the training step (runs once on each replica)
@tf.function
def train_step(inputs):
    images, labels = inputs

    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        per_example_loss = loss_fn(labels, predictions)
        # Divide by the global batch size so that summing across replicas
        # gives the mean loss over the whole batch
        loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=BATCH_SIZE)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_accuracy.update_state(labels, predictions)
    return loss

# Define the distributed training step
@tf.function
def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

# Custom training loop
epochs = 3
for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0
    train_accuracy.reset_state()

    for batch in train_dist_dataset:
        total_loss += distributed_train_step(batch)
        num_batches += 1

    train_loss = total_loss / num_batches

    print(f"Epoch {epoch+1}, Loss: {train_loss.numpy():.4f}, Accuracy: {train_accuracy.result().numpy():.4f}")

Output:

Epoch 1, Loss: 0.2033, Accuracy: 0.9386
Epoch 2, Loss: 0.0648, Accuracy: 0.9798
Epoch 3, Loss: 0.0461, Accuracy: 0.9855
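
In a custom loop the loss has to be scaled by hand: each replica divides the sum of its per-example losses by the global batch size, and the per-replica results are then summed. This is the scaling that `tf.nn.compute_average_loss` performs. The pure-Python sketch below uses made-up loss values and a deliberately uneven split to show why averaging per replica and then averaging again gives a different number.

```python
def average_loss(per_example_losses, global_batch_size):
    """Sum of one replica's per-example losses divided by the GLOBAL batch size."""
    return sum(per_example_losses) / global_batch_size

# Hypothetical per-example losses on two replicas (uneven split: 3 + 1).
replica_0 = [0.9, 0.3, 0.6]
replica_1 = [0.2]
global_batch_size = len(replica_0) + len(replica_1)

# Mean loss over the whole global batch (what a single device would report).
true_mean = sum(replica_0 + replica_1) / global_batch_size

# Scale on each replica, then SUM across replicas.
summed_scaled = (average_loss(replica_0, global_batch_size)
                 + average_loss(replica_1, global_batch_size))

# Averaging each replica separately and then averaging the averages.
mean_of_means = (sum(replica_0) / len(replica_0) + sum(replica_1) / len(replica_1)) / 2

print(true_mean, summed_scaled, mean_of_means)
```

Only the scaled-then-summed value agrees with the single-device mean, which is why the reduction across replicas uses `ReduceOp.SUM` rather than a mean.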

Real-World Application: Training a ResNet on CIFAR-10

Let's implement a more realistic example by training a ResNet model on the CIFAR-10 dataset:

python
# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Create tf.data datasets
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

# Define a ResNet block
def resnet_block(x, filters, kernel_size=3, stride=1, conv_shortcut=False):
    shortcut = x

    if conv_shortcut:
        # Project the shortcut so its shape matches the main branch
        shortcut = layers.Conv2D(filters, 1, strides=stride)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(filters, kernel_size, padding='same')(x)
    x = layers.BatchNormalization()(x)

    x = layers.add([x, shortcut])
    x = layers.ReLU()(x)

    return x

# Create ResNet model with strategy
with strategy.scope():
    inputs = layers.Input(shape=(32, 32, 3))

    # Initial convolution
    x = layers.Conv2D(64, 7, strides=2, padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.MaxPooling2D(3, strides=2, padding='same')(x)

    # ResNet blocks
    x = resnet_block(x, 64, conv_shortcut=True)
    x = resnet_block(x, 64)

    x = resnet_block(x, 128, stride=2, conv_shortcut=True)
    x = resnet_block(x, 128)

    x = resnet_block(x, 256, stride=2, conv_shortcut=True)
    x = resnet_block(x, 256)

    # Final layers
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(10)(x)

    resnet_model = tf.keras.Model(inputs, outputs)

    # Compile model
    resnet_model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )

# Train the model
resnet_history = resnet_model.fit(
    train_dataset,
    epochs=10,
    validation_data=test_dataset
)

This example demonstrates how to use MirroredStrategy with a deeper model (a small ResNet) on CIFAR-10. Apart from the strategy scope, the code is the same as it would be for a single GPU; how much training time the extra GPUs save depends on the model size and on how fast the input pipeline can deliver data.
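
To see how the strides in this network shrink the feature maps, the spatial size can be traced by hand. The sketch below assumes only that a stride-`s` layer with `'same'` padding (or a 1x1 convolution) maps a size `n` to `ceil(n / s)`; the helper name `same_out` and the stage labels are made up for illustration.

```python
import math

def same_out(size, stride):
    """Output height/width of a strided layer with 'same' padding."""
    return math.ceil(size / stride)

size = 32  # CIFAR-10 images are 32x32
trace = []
for name, stride in [("initial conv", 2), ("max pool", 2),
                     ("64-filter blocks", 1), ("128-filter blocks", 2),
                     ("256-filter blocks", 2)]:
    size = same_out(size, stride)
    trace.append((name, size))

for name, out in trace:
    print(f"{name}: {out}x{out}")
```

The feature map is down to 2x2 before global average pooling, so for inputs this small a common variation is to drop the initial stride or the max pooling layer.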

Best Practices for MirroredStrategy

When using MirroredStrategy, consider the following best practices:

  1. Batch Size: Increase your batch size proportionally to the number of GPUs. For example, if your single-GPU batch size is 32 and you're using 4 GPUs, set your total batch size to 128.

  2. Learning Rate: You may need to adjust your learning rate when scaling up batch sizes. A common approach is to use the "square root scaling rule": multiply your learning rate by the square root of the batch size increase.

    python
    # Example of learning rate scaling
    single_gpu_batch_size = 32
    multi_gpu_batch_size = 32 * strategy.num_replicas_in_sync
    single_gpu_learning_rate = 0.001

    # Square root scaling rule
    multi_gpu_learning_rate = single_gpu_learning_rate * (multi_gpu_batch_size / single_gpu_batch_size) ** 0.5

    with strategy.scope():
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=multi_gpu_learning_rate),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy']
        )

  3. Memory Usage: Be aware of memory limitations on your GPUs. If your model is large, you might need to adjust the batch size or model architecture to fit within the available GPU memory.

  4. Dataset Preparation: Use the tf.data API for efficient data loading and preprocessing so that training does not become I/O bound.

    python
    # Optimize dataset for performance
    train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                     .shuffle(10000)
                     .batch(BATCH_SIZE)
                     .prefetch(tf.data.AUTOTUNE))

  5. Variable Creation: Always create variables (models, optimizers, etc.) within the strategy's scope.
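
For the learning-rate advice above, it can help to tabulate what a scaling rule produces before committing to it. The sketch below compares the square root rule with the linear scaling rule, another widely used heuristic that is not part of the list above; `scaled_learning_rate` is a hypothetical helper, and it assumes the per-replica batch size stays fixed so the batch grows by a factor equal to the number of replicas.

```python
def scaled_learning_rate(base_lr, num_replicas, rule="sqrt"):
    """Scale a single-GPU learning rate for a batch that is num_replicas times larger."""
    if rule == "sqrt":
        return base_lr * num_replicas ** 0.5
    if rule == "linear":
        return base_lr * num_replicas
    raise ValueError(f"unknown rule: {rule!r}")

for replicas in (1, 2, 4, 8):
    sqrt_lr = scaled_learning_rate(0.001, replicas, "sqrt")
    linear_lr = scaled_learning_rate(0.001, replicas, "linear")
    print(f"{replicas} replica(s): sqrt={sqrt_lr:.5f}  linear={linear_lr:.5f}")
```

Neither rule is guaranteed to work for a given model and optimizer, so treat the result as a starting point for tuning rather than a final value.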

Common Issues and Solutions

1. Out of Memory (OOM) Errors

If you encounter OOM errors, try:

  • Reducing batch size per GPU
  • Using mixed precision training
  • Simplifying your model architecture

python
# Enable mixed precision training
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')

with strategy.scope():
    # Model definition here
    ...

2. Training Slower Than Expected

If the speedup is less than expected:

  • Check if your training is I/O bound (slow data loading)
  • Ensure your model is complex enough to benefit from distribution
  • Check for uneven GPU utilization

3. Unexpected Results

If you get different results compared to single-GPU training:

  • Ensure random seeds are set consistently
  • Check if batch normalization is applied correctly across devices
  • Verify your learning rate adjustments
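
The batch normalization point deserves a closer look. With a standard batch normalization layer, each replica normalizes with the mean and variance of its own slice, not of the global batch. The pure-Python sketch below uses made-up activations to show how far the two can diverge when the slices differ.

```python
from statistics import mean, pvariance

# Hypothetical activations of one feature on two replicas.
replica_0 = [1.0, 2.0, 3.0, 4.0]
replica_1 = [10.0, 11.0, 12.0, 13.0]

# What each replica computes on its own slice.
local_stats = [(mean(r), pvariance(r)) for r in (replica_0, replica_1)]

# What a single device would compute on the whole global batch.
global_batch = replica_0 + replica_1
global_stats = (mean(global_batch), pvariance(global_batch))

print("per replica:", local_stats)
print("global:     ", global_stats)
```

If this mismatch matters for your model, for example with very small per-replica batches, TensorFlow offers synchronized batch normalization, which aggregates statistics across replicas; the exact layer or argument to use depends on your TensorFlow version.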

Summary

In this tutorial, we've learned:

  1. How MirroredStrategy enables efficient multi-GPU training on a single machine
  2. How to set up and implement MirroredStrategy for both simple and complex models
  3. How to create custom training loops with MirroredStrategy
  4. Best practices for optimizing multi-GPU training
  5. Common issues and their solutions

TensorFlow's MirroredStrategy provides an easy way to utilize multiple GPUs on a single machine, significantly reducing training time for large models and datasets. By following the best practices outlined in this tutorial, you can efficiently scale up your deep learning training workflows.

Exercises

  1. Modify the MNIST example to use mixed precision training with MirroredStrategy.
  2. Implement a custom callback that reports the training time per epoch when using MirroredStrategy.
  3. Experiment with different batch sizes and learning rates to find the optimal configuration for your hardware.
  4. Implement a more complex model (like EfficientNet or BERT) with MirroredStrategy.
  5. Compare the performance of MirroredStrategy with other distribution strategies like ParameterServerStrategy or TPUStrategy.

