PyTorch Ray Integration
Introduction
When training PyTorch models on large datasets or complex architectures, distributed training becomes essential. One powerful framework for distributed computing is Ray, which integrates seamlessly with PyTorch to provide easy-to-use distributed training capabilities.
Ray is an open-source unified framework for scaling AI and Python applications. By combining Ray with PyTorch, you can:
- Train models across multiple GPUs and machines with minimal code changes
- Run hyperparameter tuning in parallel
- Scale your training pipeline efficiently
In this tutorial, we'll explore how to integrate Ray with PyTorch for distributed training and hyperparameter optimization.
Getting Started
Prerequisites
Before diving in, make sure you have the following packages installed:
pip install torch torchvision "ray[train,tune]"
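To confirm that the packages are installed and that PyTorch can see a GPU, a quick check:

import ray
import torch

print(ray.__version__)            # e.g. 2.x
print(torch.cuda.is_available())  # True if a CUDA-capable GPU is usable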
Basic Ray Concepts
Ray provides several core abstractions that make distributed computing easier:
- Ray Tasks - Functions that execute asynchronously on the Ray cluster
- Ray Actors - Classes that maintain state across multiple method calls
- Ray Tune - A library for hyperparameter tuning
- Ray Train - A library for distributed model training
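To make the first two concepts concrete, here is a minimal sketch of a Ray task and a Ray actor (assuming a local ray.init(); the function and class names are just illustrative):

import ray

ray.init()  # Start Ray locally; on a cluster you would pass an address

@ray.remote
def square(x):
    # A Ray task: runs asynchronously on any worker in the cluster
    return x * x

@ray.remote
class Counter:
    # A Ray actor: keeps state across method calls
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count

futures = [square.remote(i) for i in range(4)]
print(ray.get(futures))  # [0, 1, 4, 9]

counter = Counter.remote()
print(ray.get([counter.increment.remote() for _ in range(3)]))  # [1, 2, 3]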
Distributed Training with Ray Train
Ray Train makes it simple to train your PyTorch models across multiple GPUs or machines. Let's start with a basic example.
Simple Example: Training a CNN on MNIST
First, let's create a simple PyTorch model for the MNIST dataset:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
Now, let's implement a training function that will work with Ray Train:
import ray
from ray import train
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
def train_func(config):
    # Read hyperparameters from the config
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]

    # Set up data loaders
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST('data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST('data', train=False, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Shard the data across workers and move batches to the right device
    train_loader = train.torch.prepare_data_loader(train_loader)
    test_loader = train.torch.prepare_data_loader(test_loader)

    # Create the model
    model = SimpleCNN()

    # Prepare the model for distributed training
    model = train.torch.prepare_model(model)

    # Loss and optimizer (the model returns log-probabilities, so use NLLLoss)
    criterion = nn.NLLLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Training loop
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if i % 100 == 99:
                print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}, Acc: {100.*correct/total:.2f}%")
                running_loss = 0.0

        # Validation
        model.eval()
        test_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        accuracy = 100. * correct / total
        print(f"Epoch {epoch+1}, Test Loss: {test_loss/len(test_loader):.3f}, Test Acc: {accuracy:.2f}%")

        # Report metrics to Ray Train
        train.report({"accuracy": accuracy, "loss": test_loss/len(test_loader)})
Now, let's set up the Ray trainer:
# Initialize Ray
ray.init()
# Configure the training
trainer = TorchTrainer(
    train_func,
    train_loop_config={
        "lr": 0.01,
        "batch_size": 64,
        "epochs": 5
    },
    scaling_config=ScalingConfig(
        num_workers=2,  # Number of workers to use
        use_gpu=torch.cuda.is_available(),  # Use GPUs if available
    )
)
# Start training
result = trainer.fit()
print(f"Training completed with accuracy: {result.metrics['accuracy']:.2f}%")
The output should look something like:
Epoch 1, Batch 100, Loss: 0.224, Acc: 93.27%
Epoch 1, Batch 200, Loss: 0.097, Acc: 96.88%
...
Epoch 1, Test Loss: 0.045, Test Acc: 98.72%
...
Epoch 5, Test Loss: 0.029, Test Acc: 99.15%
Training completed with accuracy: 99.15%
Hyperparameter Tuning with Ray Tune
Ray Tune makes it easy to find the optimal hyperparameters for your model. Let's modify our example to include hyperparameter tuning:
from ray import tune
from ray.tune import Tuner
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search import ConcurrencyLimiter
def train_func_for_tune(config):
    # Same training function as above, but now the hyperparameters in `config`
    # are supplied by Ray Tune for each trial.
    # For brevity, the full training code is not repeated here; the key point
    # is that each epoch reports its metrics back to Ray Tune:
    train.report({"accuracy": accuracy, "loss": test_loss/len(test_loader)})
# Define the search space
search_space = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([16, 32, 64, 128]),
    "epochs": 3  # Fixed number of epochs for each trial
}
# Set up the search algorithm. BayesOptSearch only supports continuous
# parameters, so we use OptunaSearch here, which also handles categorical
# choices such as batch_size (requires `pip install optuna`).
optuna_search = OptunaSearch(metric="accuracy", mode="max")
search_alg = ConcurrencyLimiter(optuna_search, max_concurrent=2)
# Configure the training
trainer = TorchTrainer(
    train_func_for_tune,
    scaling_config=ScalingConfig(
        num_workers=2,
        use_gpu=torch.cuda.is_available(),
    )
)
# Start hyperparameter tuning
tuner = Tuner(
    trainer,
    param_space={"train_loop_config": search_space},
    tune_config=tune.TuneConfig(
        search_alg=search_alg,
        num_samples=10,  # Number of trials
        max_concurrent_trials=2  # Run 2 trials in parallel
    )
)
results = tuner.fit()
# Get the best result
best_result = results.get_best_result(metric="accuracy", mode="max")
print(f"Best trial config: {best_result.config}")
print(f"Best trial accuracy: {best_result.metrics['accuracy']:.4f}")
The output might look like:
Trial progress:
Trial 1: RUNNING, Trial 2: RUNNING
...
Trial 1: COMPLETED, Trial 2: COMPLETED, Trial 3: RUNNING, Trial 4: RUNNING
...
Best trial config: {'train_loop_config': {'lr': 0.00743, 'batch_size': 64, 'epochs': 3}}
Best trial accuracy: 99.2100
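Beyond the single best trial, the ResultGrid returned by tuner.fit() can be exported as a pandas DataFrame for further analysis. A short sketch (the flattened column names, such as config/train_loop_config/lr, are an assumption and may differ slightly across Ray versions):

# Collect every trial's final metrics and config into one DataFrame
df = results.get_dataframe()

# Sort trials by reported accuracy; the config column names below are
# assumptions based on Ray's flattened naming scheme
print(df.sort_values("accuracy", ascending=False)[
    ["accuracy", "loss", "config/train_loop_config/lr", "config/train_loop_config/batch_size"]
].head())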
Scaling to Multiple Machines
Ray truly shines when scaling across multiple machines. Let's look at how to set up a multi-node training job:
# On the head node
import ray
# Start Ray with cluster resources
ray.init(address="auto") # Connects to an existing Ray cluster
# Define the training configuration with more resources
trainer = TorchTrainer(
    train_func,
    train_loop_config={
        "lr": 0.01,
        "batch_size": 64,
        "epochs": 5
    },
    scaling_config=ScalingConfig(
        num_workers=8,  # Use 8 workers distributed across the cluster
        use_gpu=True,  # Use GPUs
        resources_per_worker={"CPU": 2, "GPU": 1},  # Each worker gets 2 CPUs and 1 GPU
    )
)
# Start training across the cluster
result = trainer.fit()
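Before launching a multi-node job, it is worth confirming that the cluster actually exposes the resources the ScalingConfig asks for. Assuming you are already connected with ray.init(address="auto"):

# Inspect total and currently unclaimed resources in the cluster
print(ray.cluster_resources())    # e.g. {'CPU': 64.0, 'GPU': 8.0, ...}
print(ray.available_resources())  # Resources not yet claimed by tasks or actors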
Real-World Example: Training ResNet on ImageNet
Let's implement a more realistic example using ResNet on a subset of ImageNet:
import os
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import ray
from ray import train
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
def train_resnet(config):
    # Set random seeds for reproducibility
    torch.manual_seed(42)

    # Get hyperparameters from config
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]
    model_name = config["model_name"]

    # Data transformations
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # Load datasets (using a subset of ImageNet for demonstration)
    # In a real-world scenario, you'd use the full ImageNet dataset
    data_path = os.path.expanduser("~/data/tiny-imagenet-200")
    train_dataset = datasets.ImageFolder(os.path.join(data_path, 'train'), train_transform)
    val_dataset = datasets.ImageFolder(os.path.join(data_path, 'val'), val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Shard the data across workers and move batches to the right device
    train_loader = train.torch.prepare_data_loader(train_loader)
    val_loader = train.torch.prepare_data_loader(val_loader)

    # Load model (the old `pretrained` flag is deprecated; use `weights` instead)
    if model_name == "resnet18":
        model = models.resnet18(weights=None)
    elif model_name == "resnet50":
        model = models.resnet50(weights=None)
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Adjust the final layer to match the number of classes in our dataset
    num_classes = len(train_dataset.classes)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Prepare model for distributed training
    model = train.torch.prepare_model(model)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=0.9,
        weight_decay=1e-4
    )

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # Training loop
    for epoch in range(epochs):
        # Train phase
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        train_acc = 100. * correct / total
        train_loss = train_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        val_acc = 100. * correct / total
        val_loss = val_loss / len(val_loader)

        # Adjust learning rate
        scheduler.step()

        # Log metrics
        print(f"Epoch {epoch+1}/{epochs}: "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Report metrics to Ray
        train.report({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc
        })
# Initialize Ray
ray.init()
# Configure the training
trainer = TorchTrainer(
    train_resnet,
    train_loop_config={
        "lr": 0.01,
        "batch_size": 32,
        "epochs": 10,
        "model_name": "resnet18"
    },
    scaling_config=ScalingConfig(
        num_workers=4,  # Use 4 workers
        use_gpu=True,  # Use GPUs
    )
)
# Start training
result = trainer.fit()
print(f"Training completed with accuracy: {result.metrics['val_accuracy']:.2f}%")
Advanced Topics
Fault Tolerance with Ray
One of Ray's strengths is its fault tolerance. If a worker process or node fails during training, Ray Tune can retry the affected trial and, when checkpoints are reported, resume it from the last checkpoint rather than starting over:
from ray.train import RunConfig, FailureConfig
from ray.tune.stopper import TrialPlateauStopper
# Configure the tuner with fault-tolerance options
tuner = Tuner(
    trainer,
    param_space={"train_loop_config": search_space},
    tune_config=tune.TuneConfig(
        search_alg=search_alg,
        num_samples=10,
        max_concurrent_trials=2,
    ),
    run_config=RunConfig(
        failure_config=FailureConfig(
            max_failures=3,  # Allow up to 3 failures per trial
        ),
        # Stop trials if val_accuracy shows no improvement over 5 results
        stop=TrialPlateauStopper(metric="val_accuracy", std=0.001, num_results=5, grace_period=5),
    )
)
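For a failed trial to resume rather than restart from scratch, the training function also needs to save checkpoints. Below is a minimal sketch of epoch-level checkpointing, assuming Ray 2.x; the file name model.pt is arbitrary:

import os
import tempfile

import torch
from ray import train
from ray.train import Checkpoint

# ... inside the training function, at the end of each epoch ...
with tempfile.TemporaryDirectory() as tmpdir:
    # Save the model weights to a temporary directory and report them
    # alongside the epoch's metrics
    torch.save(model.state_dict(), os.path.join(tmpdir, "model.pt"))
    train.report(
        {"val_accuracy": val_acc, "val_loss": val_loss},
        checkpoint=Checkpoint.from_directory(tmpdir),
    )

On a restarted trial, train.get_checkpoint() returns the most recently reported checkpoint, so the training loop can reload the saved state dict before continuing.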
Custom Metrics and Callbacks
Ray supports custom metrics and callbacks for monitoring training progress:
from ray.tune import Callback
from ray.tune.logger import JsonLoggerCallback, TBXLoggerCallback

# Create a custom callback that tracks the best validation accuracy seen so far
class MyCustomCallback(Callback):
    def __init__(self):
        self.best_accuracy = 0

    def on_trial_result(self, iteration, trials, trial, result, **info):
        current_accuracy = result["val_accuracy"]
        if current_accuracy > self.best_accuracy:
            self.best_accuracy = current_accuracy
            print(f"New best accuracy: {self.best_accuracy:.2f}%")

# Configure the trainer with callbacks
trainer = TorchTrainer(
    train_resnet,
    train_loop_config={
        "lr": 0.01,
        "batch_size": 32,
        "epochs": 10,
        "model_name": "resnet18"
    },
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    # Add callbacks via the run configuration
    run_config=train.RunConfig(
        callbacks=[
            JsonLoggerCallback(),  # Log results as JSON
            TBXLoggerCallback(),   # Log results to TensorBoard
            MyCustomCallback()     # Our custom callback
        ]
    )
)
Summary
In this tutorial, you've learned how to integrate PyTorch with Ray for distributed training and hyperparameter tuning. We covered:
- Basic Ray concepts and how they relate to PyTorch
- Setting up distributed training with Ray Train
- Hyperparameter tuning with Ray Tune
- Scaling to multiple machines
- A real-world example using ResNet on ImageNet
- Advanced topics like fault tolerance and custom callbacks
Ray's integration with PyTorch provides a powerful and flexible framework for scaling your deep learning workloads across multiple GPUs and machines, while keeping your code clean and maintainable.
Exercises
- Modify the MNIST example to use a different model architecture like a deeper CNN or a transformer.
- Implement early stopping with Ray Tune to automatically terminate underperforming trials.
- Explore Ray's Dataset API for efficient data loading in distributed training.
- Implement a custom search algorithm for hyperparameter tuning.
- Scale the ResNet example to use a larger subset of ImageNet with more workers.