TensorFlow Distribution Strategies

Introduction

When training large machine learning models or working with massive datasets, distributing your workload across multiple devices (GPUs, TPUs) or machines can dramatically reduce training time. TensorFlow's Distribution Strategies API provides a high-level interface to distribute training across multiple processing units with minimal changes to your existing code.

In this tutorial, you'll learn:

  • What distribution strategies are and why they're important
  • The different types of distribution strategies available in TensorFlow
  • How to implement each strategy with practical code examples
  • Best practices and common pitfalls when using distribution strategies

What are Distribution Strategies?

Distribution strategies in TensorFlow are a set of APIs that simplify the process of distributing training across multiple GPUs, TPUs, or even machines. They handle the complexity of:

  • Replicating models across devices
  • Distributing input data
  • Aggregating results
  • Managing variables
  • Handling communication between devices

Instead of manually coding these complex operations, distribution strategies let you focus on your model architecture and training logic.

Available Distribution Strategies

TensorFlow offers several distribution strategies optimized for different hardware setups:

  1. MirroredStrategy: For training on multiple GPUs on a single machine
  2. TPUStrategy: For training on TPUs
  3. MultiWorkerMirroredStrategy: For training across multiple machines with multiple GPUs
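  4. ParameterServerStrategy: For asynchronous training with a parameter server architecture
  5. CentralStorageStrategy: Similar to MirroredStrategy, but variables are kept on the CPU instead of being mirrored on each GPU (available as tf.distribute.experimental.CentralStorageStrategy)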

Basic Usage Pattern

All distribution strategies follow a common usage pattern:

  1. Create a strategy instance
  2. Create and compile your model within the strategy's scope
  3. Use the model for training or evaluation as usual
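Here is a minimal sketch of those three steps. It picks a strategy from the visible hardware. The helper name `choose_strategy` is hypothetical. Falling back to the default strategy (`tf.distribute.get_strategy()`) on a single device is one reasonable choice, not a requirement. The tiny model and random data are placeholders.

```python
import tensorflow as tf

def choose_strategy():
    """Hypothetical helper: pick a strategy based on the visible GPUs."""
    gpus = tf.config.list_physical_devices('GPU')
    if len(gpus) > 1:
        # Several local GPUs: replicate the model on each one
        return tf.distribute.MirroredStrategy()
    # One GPU or CPU only: the default strategy runs on a single device
    return tf.distribute.get_strategy()

# 1. Create a strategy instance
strategy = choose_strategy()

# 2. Create and compile the model within the strategy's scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer='sgd', loss='mse')

# 3. Train as usual (random placeholder data, only to show the call)
x = tf.random.normal((32, 4))
y = tf.random.normal((32, 1))
model.fit(x, y, batch_size=8, epochs=1, verbose=0)
print(f"Replicas in sync: {strategy.num_replicas_in_sync}")
```

Because the model code is identical in both branches, you can switch hardware without touching the training code.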

Let's explore each strategy with practical examples.

MirroredStrategy: Training on Multiple GPUs

The most common scenario is training on multiple GPUs within a single machine. MirroredStrategy creates a copy of your model on each GPU and uses all-reduce operations to synchronize gradients.

```python
import tensorflow as tf

# Check if GPUs are available
print(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}")

# Create a MirroredStrategy
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# Create the model, optimizer, and loss function within the strategy's scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Load and preprocess MNIST data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train.reshape(-1, 784).astype('float32')
x_test = x_test.reshape(-1, 784).astype('float32')

# Create tf.data.Dataset
batch_size = 64 * strategy.num_replicas_in_sync  # Scale batch size by number of replicas
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

# Train the model
model.fit(train_dataset, epochs=5)

# Evaluate the model
eval_loss, eval_accuracy = model.evaluate(test_dataset)
print(f'Evaluation accuracy: {eval_accuracy:.3f}')
```

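Sample output (illustrative; loss, accuracy, and timings will vary). With 2 replicas the global batch size is 128. Keras would therefore report 469 training steps and 79 evaluation steps per epoch. The 938 and 157 steps shown below correspond to a batch size of 64: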

Num GPUs Available: 2
Number of devices: 2
Epoch 1/5
938/938 [==============================] - 3s 2ms/step - loss: 0.2341 - accuracy: 0.9316
Epoch 2/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0973 - accuracy: 0.9702
Epoch 3/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0680 - accuracy: 0.9792
Epoch 4/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0521 - accuracy: 0.9837
Epoch 5/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0412 - accuracy: 0.9872
157/157 [==============================] - 0s 2ms/step - loss: 0.0702 - accuracy: 0.9776
Evaluation accuracy: 0.978

Key Points About MirroredStrategy

  1. Automatic Batch Splitting: The strategy automatically divides each global batch across the available GPUs. With a global batch size of 64 and 2 GPUs, each GPU processes 32 samples per step.

  2. Batch Size Considerations: It's common practice to increase your batch size proportionally to the number of devices. If your original single-GPU batch size was 64, with 4 GPUs you might use a batch size of 256.

  3. Variable Handling: Variables are mirrored across devices, so each GPU holds its own copy. After every step, gradients are combined with an all-reduce. Every copy then receives the same update, which keeps the replicas in sync.
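To see the split concretely, here is a minimal sketch. It assumes a TensorFlow version that provides `distribute_datasets_from_function` (TF 2.4 or later).

- Instead of splitting a global batch with `experimental_distribute_dataset`, it asks the `InputContext` for the per-replica batch size and batches the pipeline at that size.
- `strategy.reduce` then aggregates the per-replica example counts back into one value.

The synthetic `range` dataset and the per-replica size of 64 are assumptions for illustration. On a machine without GPUs, `MirroredStrategy` runs a single CPU replica.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64 * strategy.num_replicas_in_sync

def dataset_fn(input_context):
    # Convert the global batch size into the batch size seen by one replica
    per_replica_batch = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    ds = tf.data.Dataset.range(1024)
    # With several workers, each input pipeline reads a disjoint shard
    ds = ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
    return ds.batch(per_replica_batch).prefetch(tf.data.AUTOTUNE)

dist_ds = strategy.distribute_datasets_from_function(dataset_fn)
first_batch = next(iter(dist_ds))

# experimental_local_results returns one tensor per local replica
local_batches = strategy.experimental_local_results(first_batch)
per_replica_sizes = [int(t.shape[0]) for t in local_batches]

@tf.function
def count_examples(batch):
    # Runs once per replica; each replica returns its own example count
    return strategy.run(lambda x: tf.cast(tf.shape(x)[0], tf.float32), args=(batch,))

# Aggregate the per-replica counts across devices
total = strategy.reduce(tf.distribute.ReduceOp.SUM, count_examples(first_batch), axis=None)
print(per_replica_sizes, int(total.numpy()))
```

Each entry of `per_replica_sizes` equals `GLOBAL_BATCH_SIZE / num_replicas_in_sync`, and the reduced total equals the global batch size.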

TPUStrategy: Training on TPUs

If you have access to TPUs (for example, in Google Colab or on Google Cloud), TPUStrategy lets you train synchronously across all available TPU cores:

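```python
import tensorflow as tf

# Only run this on a TPU runtime (for example, a Colab TPU runtime)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)
print(f"Number of TPU devices: {strategy.num_replicas_in_sync}")

# The rest of the code follows the same pattern as MirroredStrategy
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Training and evaluation steps are the same.
# TPUs prefer static shapes, so batch with drop_remainder=True.
```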
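If the same script should also run on non-TPU runtimes, one common pattern is to try the TPU resolver and fall back to the default strategy. This is a sketch. The helper name `tpu_or_default_strategy` is hypothetical. The code assumes that `TPUClusterResolver()` raises `ValueError` when no TPU address is configured, which is how it typically behaves on CPU or GPU runtimes.

```python
import tensorflow as tf

def tpu_or_default_strategy():
    """Hypothetical helper: use TPUStrategy when a TPU is found, else the default strategy."""
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    except ValueError:
        # Assumed behavior: no TPU name/address could be found on this runtime
        return tf.distribute.get_strategy()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)

strategy = tpu_or_default_strategy()
print(f"Number of replicas: {strategy.num_replicas_in_sync}")
```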

MultiWorkerMirroredStrategy: Distributed Training Across Machines

For training across multiple machines (each with one or more GPUs), use MultiWorkerMirroredStrategy. This requires setting up a TensorFlow cluster with multiple workers:

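```python
import json
import os

import tensorflow as tf

# In a real distributed setup, every worker runs this code with its own task index.
# For demonstration purposes, this shows what worker 0 would do.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["localhost:12345", "localhost:12346"]
    },
    "task": {"type": "worker", "index": 0}
})

# TF_CONFIG must be set before the strategy is created. Training will not make
# progress until every worker listed in the cluster is running.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# The rest follows the same pattern
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

# Training and evaluation
```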

Implementing a Full MultiWorkerMirroredStrategy Example

For a complete implementation, you'll need to run separate processes on different machines with the appropriate TF_CONFIG:

Worker 0 (chief):

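```python
import json
import os

import tensorflow as tf

# Set the TF_CONFIG for this worker (worker 0 acts as the chief)
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["worker0.example.com:12345", "worker1.example.com:12345"]
    },
    "task": {"type": "worker", "index": 0}
})

# Create the strategy (after TF_CONFIG is set)
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Rest of the code follows the same pattern as before
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
```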

Worker 1:

```python
import json
import os

import tensorflow as tf

# Set a different task index for this worker
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["worker0.example.com:12345", "worker1.example.com:12345"]
    },
    "task": {"type": "worker", "index": 1}
})

# The rest of the code is identical to Worker 0
```
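The two configurations differ only in the task index. One way to avoid copy-paste mistakes is to generate TF_CONFIG from a single list of worker addresses. This is a sketch: `make_tf_config` is a hypothetical helper. How each process learns its own index (a command-line flag, or an environment variable set by your job scheduler) depends on your deployment.

```python
import json

def make_tf_config(workers, index):
    """Hypothetical helper: build the TF_CONFIG JSON string for worker `index`."""
    if not 0 <= index < len(workers):
        raise ValueError(f"index {index} is out of range for {len(workers)} workers")
    return json.dumps({
        "cluster": {"worker": list(workers)},
        "task": {"type": "worker", "index": index},
    })

WORKERS = ["worker0.example.com:12345", "worker1.example.com:12345"]

# Each process would run: os.environ["TF_CONFIG"] = make_tf_config(WORKERS, <its own index>)
for i in range(len(WORKERS)):
    print(make_tf_config(WORKERS, i))
```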

ParameterServerStrategy: Asynchronous Training

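ParameterServerStrategy implements the parameter server training architecture, where:

  • Parameter servers store the model variables
  • Workers compute gradients and apply updates to the variables on the parameter servers
  • A coordinator (the "chief" task) creates the variables, dispatches training steps to the workers, and runs the training program

Each worker updates the variables without waiting for the others, so training is asynchronous. This can be beneficial for very large clusters, where waiting for the slowest worker would be costly: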

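```python
import json
import os
import tensorflow as tf

# TF2 parameter server training also needs a "chief" task, which acts as the coordinator
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["chief0.example.com:12345"],
        "worker": ["worker0.example.com:12345", "worker1.example.com:12345"],
        "ps": ["ps0.example.com:12345", "ps1.example.com:12345"]
    },
    "task": {"type": "chief", "index": 0}  # Change according to the role
})

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

if cluster_resolver.task_type in ("worker", "ps"):
    # Workers and parameter servers only start a server and wait for the coordinator
    server = tf.distribute.Server(
        cluster_resolver.cluster_spec(),
        job_name=cluster_resolver.task_type,
        task_index=cluster_resolver.task_id,
        protocol=cluster_resolver.rpc_layer or "grpc",
        start=True)
    server.join()
else:
    # Only the coordinator creates the strategy; variable handling differs from MirroredStrategy
    strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
```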
Custom Training Loops with Distribution Strategies

While using model.fit() is convenient, you might need more control with a custom training loop:

```python
import tensorflow as tf

# Create a strategy
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# Prepare the dataset
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0

# Create tf.data.Dataset and distribute it
batch_size = 64 * strategy.num_replicas_in_sync  # global batch size
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size)
dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

# Create the model, loss, optimizer, and metrics within the strategy scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10)
    ])

    # reduction='none' returns per-example losses; scaling happens in train_step
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    optimizer = tf.keras.optimizers.Adam(0.001)

    # Define metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

# Define the per-replica training step
def train_step(inputs):
    features, labels = inputs

    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        per_example_loss = loss_fn(labels, predictions)
        # Divide by the GLOBAL batch size so that summing across replicas gives the mean
        loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=batch_size)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_accuracy.update_state(labels, predictions)
    return loss

# Define the distributed training step
@tf.function
def distributed_train_step(inputs):
    per_replica_losses = strategy.run(train_step, args=(inputs,))
    # Each replica's loss is already scaled by 1/global_batch_size, so SUM yields the mean
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

# Training loop
epochs = 5

for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0

    # Reset metrics for each epoch
    train_accuracy.reset_state()

    for step, batch in enumerate(dist_dataset):
        loss = distributed_train_step(batch)
        total_loss += float(loss)
        num_batches += 1

        if step % 100 == 0:
            print(f"Epoch {epoch+1}, Step {step}, Loss: {float(loss):.4f}")

    # Print epoch results
    print(f"Epoch {epoch+1}, Loss: {total_loss/num_batches:.4f}, Accuracy: {float(train_accuracy.result()):.4f}")
```
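Under a strategy, each replica sees only its slice of the global batch. If each replica averaged its own losses and the results were then summed across replicas, the loss would be overcounted by the number of replicas. A common remedy is to sum each replica's per-example losses and divide by the global batch size, using `tf.nn.compute_average_loss`. The toy numbers below (two imagined replicas with two examples each) are made up purely to show the arithmetic.

```python
import tensorflow as tf

# Per-example losses for a global batch of 4, split across two imagined replicas
per_example_loss = tf.constant([1.0, 2.0, 3.0, 4.0])
GLOBAL_BATCH_SIZE = 4

# Each replica sums its own examples and divides by the GLOBAL batch size
replica_0 = tf.nn.compute_average_loss(per_example_loss[:2], global_batch_size=GLOBAL_BATCH_SIZE)  # (1 + 2) / 4
replica_1 = tf.nn.compute_average_loss(per_example_loss[2:], global_batch_size=GLOBAL_BATCH_SIZE)  # (3 + 4) / 4

# Summing across replicas (what ReduceOp.SUM does) recovers the global mean
print(float(replica_0 + replica_1))  # 2.5, same as tf.reduce_mean(per_example_loss)

# Averaging per replica and then summing double-counts with two replicas
naive = tf.reduce_mean(per_example_loss[:2]) + tf.reduce_mean(per_example_loss[2:])
print(float(naive))  # 5.0
```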

Real-World Application: Distributed Image Classification

Let's implement a more realistic image classification model using transfer learning with ResNet and a distribution strategy:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Create a MirroredStrategy
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# Load and prepare the dataset (cats_vs_dogs only provides a 'train' split)
batch_size = 32 * strategy.num_replicas_in_sync
dataset, info = tfds.load('cats_vs_dogs', with_info=True, as_supervised=True)
train_dataset = dataset['train']

def preprocess(image, label):
    # Images have different sizes, so resize them before batching
    image = tf.image.resize(image, (224, 224))
    image = tf.keras.applications.resnet50.preprocess_input(image)
    return image, label

train_dataset = (train_dataset
                 .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
                 .shuffle(1000)
                 .batch(batch_size)
                 .prefetch(tf.data.AUTOTUNE))

# Create the model within strategy scope
with strategy.scope():
    base_model = tf.keras.applications.ResNet50(
        weights='imagenet',
        include_top=False,
        input_shape=(224, 224, 3)
    )
    base_model.trainable = False  # Freeze the base model

    inputs = tf.keras.Input(shape=(224, 224, 3))
    # training=False keeps BatchNorm layers in inference mode, even after unfreezing
    x = base_model(inputs, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    model = tf.keras.Model(inputs, outputs)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.0001),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

# Train the model
history = model.fit(train_dataset, epochs=5)

# Fine-tune the model
with strategy.scope():
    # Unfreeze only the last 10 layers of the base model
    base_model.trainable = True
    for layer in base_model.layers[:-10]:
        layer.trainable = False

    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.00001),  # Lower learning rate
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

# Continue training
history = model.fit(train_dataset, epochs=5)
```

Best Practices for Using Distribution Strategies

  1. Scale your batch size proportionally to the number of devices for better performance.

  2. Use tf.data.Dataset API for efficient data loading and preprocessing.

  3. Create the model, and therefore its variables, inside the strategy's scope:

    ```python
    # Wrong: the model's variables are created outside the strategy scope
    model = create_model()  # create_model() stands for your own model-building function
    with strategy.scope():
        model.compile(...)  # Too late!

    # Correct: create and compile within the strategy scope
    with strategy.scope():
        model = create_model()
        model.compile(...)
    ```

  4. Use tf.function to optimize performance, especially with custom training loops.

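  5. For multi-worker training, handle checkpoints and model saving carefully. Write them to a location that survives worker restarts (for example, a shared file system or a cloud storage bucket), and consider the BackupAndRestore callback so an interrupted fit() can resume:

    ```python
    callbacks = [
        # Keras 3 requires weights-only checkpoint paths to end in ".weights.h5";
        # replace /tmp with shared storage in a real multi-worker job
        tf.keras.callbacks.ModelCheckpoint(
            filepath='/tmp/checkpoint.weights.h5',
            save_weights_only=True,
            verbose=1
        ),
        # Periodically saves training state so training can resume after a restart
        tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')
    ]
    model.fit(..., callbacks=callbacks)
    ```
  6. Consider the communication overhead. Every step synchronizes gradients across devices, so a small model may train no faster, or even slower, on more devices.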
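To check practice 3 in your own code, `strategy.extended.variable_created_in_scope` reports whether a variable was created under a given strategy. The sketch below uses plain `tf.Variable`s for clarity. Keras layers built inside the scope create their variables the same way, but this check is only a diagnostic aid.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Created inside the scope: the strategy manages (mirrors) this variable
with strategy.scope():
    inside = tf.Variable(1.0)

# Created outside the scope: an ordinary variable the strategy does not manage
outside = tf.Variable(1.0)

print(strategy.extended.variable_created_in_scope(inside))   # True
print(strategy.extended.variable_created_in_scope(outside))  # False
```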

Common Challenges and Solutions

1. Out of Memory Errors

If you encounter OOM (Out of Memory) errors:

  • Reduce the batch size per GPU
  • Use mixed precision training to reduce memory usage:
```python
import tensorflow as tf

# Enable mixed precision
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

# Then create your strategy and model as usual
```
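Here is a minimal sketch of a mixed-precision version of the MNIST model.

- Following the TensorFlow mixed precision guide, the final softmax is kept in float32 for numerically stable outputs.
- `Model.fit` applies loss scaling automatically under the `mixed_float16` policy. A custom training loop would instead wrap its optimizer in `tf.keras.mixed_precision.LossScaleOptimizer`.
- Memory and speed gains assume GPUs with float16 support. On a CPU the code runs but is not faster.

```python
import tensorflow as tf

# Set the policy before creating the strategy and model
tf.keras.mixed_precision.set_global_policy('mixed_float16')

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(256, activation='relu'),  # computes in float16
        tf.keras.layers.Dense(10),                       # logits in float16
        # Keep the final activation in float32 for numerically stable outputs
        tf.keras.layers.Activation('softmax', dtype='float32'),
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

outputs = model(tf.zeros((2, 784)))
print(model.layers[0].compute_dtype, outputs.dtype)
```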

2. Slow Training Due to Input Pipeline Bottlenecks

If your GPUs are waiting for data:

  • Use tf.data.Dataset.cache() for small datasets
  • Use tf.data.Dataset.prefetch() to prepare data in advance:
```python
train_dataset = train_dataset.cache().prefetch(tf.data.AUTOTUNE)
```
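Putting these together, here is one common ordering for an input pipeline. The steps are: parallel `map`, then `cache`, then `shuffle`, then `batch`, then `prefetch`. This is a sketch. Random synthetic "images" stand in for a real dataset. The global batch size of 128 is an arbitrary choice.

```python
import tensorflow as tf

# Synthetic stand-in data: 1,000 28x28 "images" with integer pixel values in [0, 255]
images = tf.random.uniform((1000, 28, 28), maxval=256, dtype=tf.int32)
labels = tf.random.uniform((1000,), maxval=10, dtype=tf.int32)

def preprocess(image, label):
    # Scale to [0, 1] and flatten, as in the MNIST example
    image = tf.reshape(tf.cast(image, tf.float32) / 255.0, (784,))
    return image, label

GLOBAL_BATCH_SIZE = 128

train_dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)  # preprocess elements in parallel
    .cache()                                               # reuse preprocessed data after the first epoch
    .shuffle(1000)                                         # after cache, so each epoch gets a new order
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)                            # overlap input preparation with training
)

x, y = next(iter(train_dataset))
print(x.shape, y.shape)  # (128, 784) (128,)
```

Shuffling after `cache()` matters: shuffling before it would freeze one shuffled order into the cache.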

3. Different Results Across Runs

For reproducible results:

  • Set seeds before creating the strategy:
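```python
import random

import numpy as np
import tensorflow as tf

def set_seeds(seed=42):
    # Seed TensorFlow, NumPy, and Python's random module
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_seeds()
strategy = tf.distribute.MirroredStrategy()
```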
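Recent TensorFlow releases also offer `tf.keras.utils.set_random_seed`, which seeds Python, NumPy, and TensorFlow in one call. `tf.config.experimental.enable_op_determinism()` additionally asks ops to use deterministic kernels, at some cost in speed. The sketch below shows that re-seeding reproduces the same random draw. Exact version availability may vary, so treat the version notes in the comments as approximate.

```python
import tensorflow as tf

# Seeds Python's random module, NumPy, and TensorFlow in one call (TF 2.7+)
tf.keras.utils.set_random_seed(42)
first = tf.random.uniform((3,))

# Re-seeding resets TensorFlow's internal op-seed counter, so the draw repeats
tf.keras.utils.set_random_seed(42)
second = tf.random.uniform((3,))

print(bool(tf.reduce_all(first == second)))  # True

# Optional (TF 2.9+): request deterministic op kernels, which may be slower
tf.config.experimental.enable_op_determinism()
```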

Summary

TensorFlow's Distribution Strategies provide a powerful abstraction for distributing training across multiple devices and machines with minimal code changes. The key points to remember are:

  • Use MirroredStrategy for multiple GPUs on a single machine
  • Use TPUStrategy for training on TPUs
  • Use MultiWorkerMirroredStrategy for training across multiple machines
  • Always create and compile your model within the strategy's scope
  • Scale your batch size according to the number of devices
  • Use tf.data API for efficient data loading and preprocessing

By effectively using these strategies, you can dramatically reduce training time for large models and datasets, making it possible to tackle more complex machine learning problems.


Exercises

  1. Modify the MNIST example to use mixed precision training with MirroredStrategy.
  2. Implement a custom training loop with MultiWorkerMirroredStrategy (you can simulate multiple workers on one machine for testing).
  3. Benchmark the performance difference between training a CNN model on a single GPU vs. multiple GPUs.
  4. Experiment with different batch sizes and measure their impact on training time and model accuracy when using distribution strategies.
  5. Implement a fault-tolerant distributed training pipeline that can recover from worker failures.

