TensorFlow Distribution Strategies
Introduction
When training large machine learning models or working with massive datasets, distributing your workload across multiple devices (GPUs, TPUs) or machines can dramatically reduce training time. TensorFlow's Distribution Strategies API provides a high-level interface to distribute training across multiple processing units with minimal changes to your existing code.
In this tutorial, you'll learn:
- What distribution strategies are and why they're important
- The different types of distribution strategies available in TensorFlow
- How to implement each strategy with practical code examples
- Best practices and common pitfalls when using distribution strategies
What are Distribution Strategies?
Distribution strategies in TensorFlow are a set of APIs that simplify the process of distributing training across multiple GPUs, TPUs, or even machines. They handle the complexity of:
- Replicating models across devices
- Distributing input data
- Aggregating results
- Managing variables
- Handling communication between devices
Instead of manually coding these complex operations, distribution strategies let you focus on your model architecture and training logic.
Available Distribution Strategies
TensorFlow offers several distribution strategies optimized for different hardware setups:
- `MirroredStrategy`: For training on multiple GPUs on a single machine
- `TPUStrategy`: For training on TPUs
- `MultiWorkerMirroredStrategy`: For training across multiple machines with multiple GPUs
- `ParameterServerStrategy`: For asynchronous training with a parameter server architecture
- `CentralStorageStrategy`: Similar to `MirroredStrategy` but with variables stored on the CPU
Basic Usage Pattern
All distribution strategies follow a common usage pattern:
- Create a strategy instance
- Create and compile your model within the strategy's scope
- Use the model for training or evaluation as usual
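To make the pattern concrete, here is a minimal sketch (the strategy choice, model architecture, and dataset are placeholders rather than part of a specific example):

```python
import tensorflow as tf

# 1. Create a strategy instance (MirroredStrategy is just one option)
strategy = tf.distribute.MirroredStrategy()

# 2. Build and compile the model inside the strategy's scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, input_shape=(20,))  # placeholder architecture
    ])
    model.compile(optimizer='adam', loss='mse')

# 3. Train and evaluate as usual (dataset creation omitted here)
# model.fit(train_dataset, epochs=5)
```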
Let's explore each strategy with practical examples.
MirroredStrategy: Training on Multiple GPUs
The most common scenario is training on multiple GPUs within a single machine. `MirroredStrategy` creates a copy of your model on each GPU and uses all-reduce operations to synchronize gradients.
```python
import tensorflow as tf
import numpy as np

# Check if GPUs are available
print(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}")

# Create a MirroredStrategy
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# Create the model, optimizer, and loss function within the strategy's scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Load and preprocess MNIST data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train.reshape(-1, 784).astype('float32')
x_test = x_test.reshape(-1, 784).astype('float32')

# Create tf.data.Dataset
batch_size = 64 * strategy.num_replicas_in_sync  # Scale batch size by number of replicas
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

# Train the model
model.fit(train_dataset, epochs=5)

# Evaluate the model
eval_loss, eval_accuracy = model.evaluate(test_dataset)
print(f'Evaluation accuracy: {eval_accuracy:.3f}')
```
Output:

```
Num GPUs Available: 2
Number of devices: 2
Epoch 1/5
938/938 [==============================] - 3s 2ms/step - loss: 0.2341 - accuracy: 0.9316
Epoch 2/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0973 - accuracy: 0.9702
Epoch 3/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0680 - accuracy: 0.9792
Epoch 4/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0521 - accuracy: 0.9837
Epoch 5/5
938/938 [==============================] - 2s 2ms/step - loss: 0.0412 - accuracy: 0.9872
157/157 [==============================] - 0s 2ms/step - loss: 0.0702 - accuracy: 0.9776
Evaluation accuracy: 0.978
```
Key Points About MirroredStrategy
- Automatic Batch Splitting: The strategy automatically divides your batch across the available GPUs. If you have a batch size of 64 and 2 GPUs, each GPU will process 32 samples.
- Batch Size Considerations: It's common practice to increase your batch size proportionally to the number of devices. If your original single-GPU batch size was 64, with 4 GPUs you might use a batch size of 256.
- Variable Handling: Variables are mirrored across devices. Updates are synchronized using all-reduce operations.
TPUStrategy: Training on TPUs
If you have access to TPUs (like in Google Colab or Google Cloud), `TPUStrategy` allows you to utilize their massive processing power:
```python
import tensorflow as tf

# Only run this on a TPU runtime (like in Colab)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

print(f"Number of TPU devices: {strategy.num_replicas_in_sync}")

# The rest of the code follows the same pattern as MirroredStrategy
with strategy.scope():
    # Create and compile the model as before
    # ...
    pass

# Training and evaluation steps are the same
```
MultiWorkerMirroredStrategy: Distributed Training Across Machines
For training across multiple machines (each with one or more GPUs), use `MultiWorkerMirroredStrategy`. This requires setting up a TensorFlow cluster with multiple workers:
```python
# This would be set up on each worker in a real distributed setup.
# For demonstration purposes, we're showing what a single worker would do.
import json
import os

import tensorflow as tf

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["localhost:12345", "localhost:12346"]
    },
    "task": {"type": "worker", "index": 0}
})

strategy = tf.distribute.MultiWorkerMirroredStrategy()

# The rest follows the same pattern
with strategy.scope():
    # Create and compile the model
    # ...
    pass

# Training and evaluation
```
Implementing a Full MultiWorkerMirroredStrategy Example
For a complete implementation, you'll need to run separate processes on different machines with the appropriate TF_CONFIG:
Worker 0 (chief):
```python
import json
import os
import tensorflow as tf

# Set the TF_CONFIG for this worker
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["worker0.example.com:12345", "worker1.example.com:12345"]
    },
    "task": {"type": "worker", "index": 0}
})

# Create the strategy
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Rest of the code follows the same pattern as before
with strategy.scope():
    # Create and compile the model
    # ...
    pass
```
Worker 1:
```python
import json
import os
import tensorflow as tf

# Set different task index for this worker
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["worker0.example.com:12345", "worker1.example.com:12345"]
    },
    "task": {"type": "worker", "index": 1}
})

# The rest of the code is identical to Worker 0
```
ParameterServerStrategy: Asynchronous Training
`ParameterServerStrategy` implements the parameter server training architecture, where:
- Parameter servers store the model variables
- Workers compute gradients and send updates to parameter servers
This allows for asynchronous updates, which can be beneficial for very large clusters:
```python
import json
import os
import tensorflow as tf

# Setting TF_CONFIG for a parameter server setup
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["worker0.example.com:12345", "worker1.example.com:12345"],
        "ps": ["ps0.example.com:12345", "ps1.example.com:12345"]
    },
    "task": {"type": "worker", "index": 0}  # Change according to the role
})

# The strategy reads the cluster specification from TF_CONFIG via a cluster resolver
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

# The rest follows a similar pattern, but variable handling is different
```
Custom Training Loops with Distribution Strategies
While using `model.fit()` is convenient, you might need more control with a custom training loop:
```python
import tensorflow as tf

# Create a strategy
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# Prepare the dataset
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0

# Create tf.data.Dataset and distribute it
batch_size = 64 * strategy.num_replicas_in_sync
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size)
dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

# Create the model, loss, optimizer, and metrics within the strategy scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10)
    ])

    # Inside a replica context the loss must use an explicit reduction;
    # average the per-example losses over the *global* batch size
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=batch_size)

    optimizer = tf.keras.optimizers.Adam(0.001)

    # Define metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

# Define the per-replica training step
def train_step(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = compute_loss(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_accuracy.update_state(labels, predictions)
    return loss

# Define the distributed training step
@tf.function
def distributed_train_step(inputs):
    per_replica_losses = strategy.run(train_step, args=(inputs,))
    # Summing the per-replica (global-batch-averaged) losses gives the batch loss
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

# Training loop
epochs = 5

for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0

    # Reset metrics for each epoch
    train_accuracy.reset_states()

    for step, batch in enumerate(dist_dataset):
        loss = distributed_train_step(batch)
        total_loss += loss
        num_batches += 1

        if step % 100 == 0:
            print(f"Epoch {epoch+1}, Step {step}, Loss: {float(loss):.4f}")

    # Print epoch results
    print(f"Epoch {epoch+1}, Loss: {float(total_loss)/num_batches:.4f}, "
          f"Accuracy: {float(train_accuracy.result()):.4f}")
```
Real-World Application: Distributed Image Classification
Let's implement a more realistic image classification model using transfer learning with ResNet and a distribution strategy:
```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Create a MirroredStrategy
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# Load and prepare the dataset
batch_size = 32 * strategy.num_replicas_in_sync
dataset, info = tfds.load('cats_vs_dogs', with_info=True, as_supervised=True)
train_dataset = dataset['train']

def preprocess(image, label):
    image = tf.image.resize(image, (224, 224))
    image = tf.keras.applications.resnet50.preprocess_input(image)
    return image, label

train_dataset = train_dataset.map(preprocess).shuffle(1000).batch(batch_size)

# Create the model within strategy scope
with strategy.scope():
    base_model = tf.keras.applications.ResNet50(
        weights='imagenet',
        include_top=False,
        input_shape=(224, 224, 3)
    )
    base_model.trainable = False  # Freeze the base model

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.0001),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

# Train the model
history = model.fit(train_dataset, epochs=5)

# Fine-tune the model
with strategy.scope():
    # Unfreeze some layers
    base_model.trainable = True
    for layer in base_model.layers[:-10]:
        layer.trainable = False

    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.00001),  # Lower learning rate
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

# Continue training
history = model.fit(train_dataset, epochs=5)
```
Best Practices for Using Distribution Strategies
- Scale your batch size proportionally to the number of devices for better performance.
- Use the `tf.data.Dataset` API for efficient data loading and preprocessing.
- Be aware of differences in variable initialization:

```python
# Wrong: This will create variables outside the strategy scope
model = create_model()
with strategy.scope():
    model.compile(...)  # Too late!

# Correct: Create and compile within the strategy scope
with strategy.scope():
    model = create_model()
    model.compile(...)
```

- Use `tf.function` to optimize performance, especially with custom training loops.
- For multi-worker training, handle checkpoints and model saving carefully:

```python
# Use the ModelCheckpoint callback with proper settings
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath='/tmp/checkpoint',
        save_weights_only=True,
        verbose=1
    )
]
model.fit(..., callbacks=callbacks)
```

- Consider the communication overhead. More devices aren't always better if your model is small.
Common Challenges and Solutions
1. Out of Memory Errors
If you encounter OOM (Out of Memory) errors:
- Reduce the batch size per GPU
- Use mixed precision training to reduce memory usage:
```python
# Enable mixed precision
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

# Then create your strategy and model as usual
```
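One detail worth noting when combining mixed precision with the earlier MNIST model (shown as a sketch that assumes `strategy` has already been created after setting the policy): TensorFlow's mixed precision guide recommends keeping the final layer's outputs in float32 for numeric stability.

```python
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(128, activation='relu'),
        # Keep the output layer in float32 even though the rest runs in float16
        tf.keras.layers.Dense(10, activation='softmax', dtype='float32')
    ])
```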
2. Slow Training Due to Input Pipeline Bottlenecks
If your GPUs are waiting for data:
- Use `tf.data.Dataset.cache()` for small datasets
- Use `tf.data.Dataset.prefetch()` to prepare data in advance:

```python
train_dataset = train_dataset.cache().prefetch(tf.data.AUTOTUNE)
```
3. Different Results Across Runs
For reproducible results:
- Set seeds before creating the strategy:
```python
import random

import numpy as np
import tensorflow as tf

def set_seeds(seed=42):
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_seeds()
strategy = tf.distribute.MirroredStrategy()
```
Summary
TensorFlow's Distribution Strategies provide a powerful abstraction for distributing training across multiple devices and machines with minimal code changes. The key points to remember are:
- Use `MirroredStrategy` for multiple GPUs on a single machine
- Use `TPUStrategy` for training on TPUs
- Use `MultiWorkerMirroredStrategy` for training across multiple machines
- Always create and compile your model within the strategy's scope
- Scale your batch size according to the number of devices
- Use the `tf.data` API for efficient data loading and preprocessing
By effectively using these strategies, you can dramatically reduce training time for large models and datasets, making it possible to tackle more complex machine learning problems.
Additional Resources
- TensorFlow Distribution Strategy Guide
- TensorFlow Multi-worker Training with Keras
- Training with TPUs
Exercises
- Modify the MNIST example to use mixed precision training with `MirroredStrategy`.
- Implement a custom training loop with `MultiWorkerMirroredStrategy` (you can simulate multiple workers on one machine for testing).
- Benchmark the performance difference between training a CNN model on a single GPU vs. multiple GPUs.
- Experiment with different batch sizes and measure their impact on training time and model accuracy when using distribution strategies.
- Implement a fault-tolerant distributed training pipeline that can recover from worker failures.