PyTorch Distributed Evaluation
Introduction
When training models at scale with PyTorch's distributed capabilities, it's equally important to perform evaluation efficiently. Distributed evaluation allows you to validate your model's performance across multiple GPUs and nodes, significantly reducing the time needed to evaluate large models on substantial validation datasets.
In this guide, we'll explore how to implement distributed evaluation strategies that complement your distributed training workflows. You'll learn how to properly evaluate models across multiple devices while ensuring accurate aggregation of metrics.
Why Distributed Evaluation Matters
Evaluation can become a bottleneck when working with:
- Large models that require significant memory
- Extensive validation datasets
- Time-sensitive applications where quick feedback is necessary
By distributing the evaluation workload, you can:
- Process more samples in parallel
- Reduce evaluation time significantly
- Use the same distributed infrastructure you already set up for training
Basic Concepts of Distributed Evaluation
Distributed evaluation follows many of the same principles as distributed training, but with some key differences:
- No backward pass or gradient computation is needed
- Metrics need to be properly aggregated across all processes (see the short sketch after this list)
- For certain metrics (like accuracy), special care is needed for correct calculation
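As a quick illustration of the aggregation point above (a minimal sketch, assuming torch.distributed is already initialized, device is this process's device, and correct holds this rank's local count of correct predictions):

# Each process computes its metric locally, then all_reduce sums it in place
# across every process, so all ranks end up with the same global value.
local_correct = torch.tensor([float(correct)], device=device)
dist.all_reduce(local_correct, op=dist.ReduceOp.SUM)
global_correct = local_correct.item()  # identical on every rank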
Setting Up Distributed Evaluation
Let's start with a basic setup for distributed evaluation:
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    """Initialize the distributed environment."""
    # The default env:// init method needs a rendezvous address and port.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo",
                            rank=rank, world_size=world_size)

def cleanup():
    """Clean up the distributed environment."""
    dist.destroy_process_group()
def distributed_evaluate(rank, world_size, model, val_dataset):
    # Initialize the distributed process group
    setup(rank, world_size)

    # Move model to the appropriate device
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Wrap the model with DDP
    model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None)

    # Create distributed sampler for validation data
    val_sampler = DistributedSampler(
        val_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False  # Important for evaluation: maintain order
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        sampler=val_sampler,
        num_workers=2,
        pin_memory=True
    )

    # Now run evaluation
    model.eval()
    with torch.no_grad():
        # Your evaluation logic goes here
        pass

    cleanup()
Implementing the Evaluation Loop
Now let's implement a complete evaluation function with proper metric aggregation:
def distributed_evaluate(rank, world_size, model, val_dataset, criterion):
    # Setup and data loading (as above)
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None)

    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=32, sampler=val_sampler, num_workers=2)

    # Evaluation loop
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            # Accumulate the loss weighted by batch size so that a correct
            # per-sample average can be computed after aggregation
            loss = criterion(outputs, targets)
            val_loss += loss.item() * targets.size(0)

            # Calculate accuracy
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Aggregate metrics from all processes
    val_loss_tensor = torch.tensor([val_loss], device=device)
    correct_tensor = torch.tensor([correct], device=device)
    total_tensor = torch.tensor([total], device=device)

    # Sum the metrics across all processes
    dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
    dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
    dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)

    # Calculate the global metrics
    val_loss = val_loss_tensor.item() / total_tensor.item()  # Average loss per sample
    accuracy = 100.0 * correct_tensor.item() / total_tensor.item()

    # Print results only on the main process
    if rank == 0:
        print(f"Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%")

    cleanup()
    return val_loss, accuracy
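A small variation worth knowing (a sketch under the same setup as above): if only rank 0 ever consumes the aggregated numbers, dist.reduce can replace dist.all_reduce, so the summed result lands on the destination rank only and the other ranks skip receiving it.

# Reduce to rank 0 only; only rank 0 holds the correct sums afterwards.
dist.reduce(val_loss_tensor, dst=0, op=dist.ReduceOp.SUM)
dist.reduce(correct_tensor, dst=0, op=dist.ReduceOp.SUM)
dist.reduce(total_tensor, dst=0, op=dist.ReduceOp.SUM)
if rank == 0:
    val_loss = val_loss_tensor.item() / total_tensor.item()
    accuracy = 100.0 * correct_tensor.item() / total_tensor.item()
    print(f"Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%")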
Running the Distributed Evaluation
To run our distributed evaluation, we need to launch it using PyTorch's multiprocessing:
def main():
    # Create model and dataset
    model = YourModel()
    val_dataset = YourDataset()
    criterion = nn.CrossEntropyLoss()

    # Determine world size and run distributed evaluation
    world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1

    if world_size > 1:
        print(f"Using {world_size} GPUs for distributed evaluation")
        mp.spawn(
            distributed_evaluate,
            args=(world_size, model, val_dataset, criterion),  # rank is supplied by mp.spawn
            nprocs=world_size,
            join=True
        )
    else:
        # Fall back to single-device evaluation (still a 1-process group)
        print("Using single device for evaluation")
        distributed_evaluate(0, 1, model, val_dataset, criterion)

if __name__ == "__main__":
    main()
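If you prefer not to manage mp.spawn yourself, the same evaluation can be launched with torchrun, which starts one process per GPU and sets RANK, WORLD_SIZE, and LOCAL_RANK in the environment. A minimal sketch follows; the script name evaluate.py and the function name main_torchrun are placeholders for your own code.

# Launch with:  torchrun --nproc_per_node=4 evaluate.py
import os
import torch
import torch.distributed as dist

def main_torchrun():
    # torchrun provides the rendezvous info, so no MASTER_ADDR/PORT or mp.spawn is needed
    dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    # ... build the model and dataset, then reuse the evaluation loop above ...
    dist.destroy_process_group()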
Advanced Techniques for Distributed Evaluation
1. Handling Custom Metrics
For custom or complex metrics like mean Average Precision (mAP) or F1-score, you typically need the full set of predictions and targets in one place. Here we gather them with all_gather (so every process ends up with a complete copy) and let rank 0 compute the metrics:
from sklearn.metrics import precision_score, recall_score, f1_score

def distributed_evaluate_with_custom_metrics(rank, world_size, model, val_dataset):
    # Setup, model wrapping and data loading as before
    # ...

    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            predictions = torch.softmax(outputs, dim=1)

            # Store predictions and targets
            all_predictions.append(predictions.cpu())
            all_targets.append(targets.cpu())

    # Concatenate all predictions and targets from this process
    all_predictions = torch.cat(all_predictions)
    all_targets = torch.cat(all_targets)

    # Gather predictions and targets from all processes.
    # Note: with the NCCL backend, tensors passed to all_gather must live on the GPU,
    # so move them to `device` before the collective and back to the CPU afterwards.
    all_predictions = all_predictions.to(device)
    all_targets = all_targets.to(device)
    gathered_predictions = [torch.zeros_like(all_predictions) for _ in range(world_size)]
    gathered_targets = [torch.zeros_like(all_targets) for _ in range(world_size)]
    dist.all_gather(gathered_predictions, all_predictions)
    dist.all_gather(gathered_targets, all_targets)

    # Every process now holds the full set; rank 0 calculates the metrics
    if rank == 0:
        # Concatenate all gathered tensors
        full_predictions = torch.cat(gathered_predictions).cpu()
        full_targets = torch.cat(gathered_targets).cpu()

        # Now calculate advanced metrics
        predicted_labels = full_predictions.argmax(dim=1).numpy()
        precision = precision_score(full_targets.numpy(), predicted_labels, average='macro')
        recall = recall_score(full_targets.numpy(), predicted_labels, average='macro')
        f1 = f1_score(full_targets.numpy(), predicted_labels, average='macro')

        print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")

    cleanup()
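One caveat with the snippet above: all_gather requires every rank to contribute tensors of the same shape, which DistributedSampler's default padding guarantees, but which breaks if you drop the padding or filter samples per rank. In that case you can gather arbitrary picklable objects instead. A sketch, assuming the same variables as above:

# all_gather_object pickles each rank's payload, so shapes don't have to match.
# (With the NCCL backend, make sure torch.cuda.set_device(rank) was called first.)
object_list = [None for _ in range(world_size)]
dist.all_gather_object(object_list, {"preds": all_predictions.cpu(), "targets": all_targets.cpu()})
if rank == 0:
    full_predictions = torch.cat([o["preds"] for o in object_list])
    full_targets = torch.cat([o["targets"] for o in object_list])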
2. Efficient Evaluation with Checkpointing
When evaluating large models, building the model in the parent process and shipping a pickled copy to every worker (which is what mp.spawn does with its args) can be slow and memory-hungry. A leaner pattern is to pass only a checkpoint path and let each process construct the model and load the weights directly onto its own device:
def evaluate_from_checkpoint(rank, world_size, checkpoint_path, model_class, val_dataset):
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    # Load model from checkpoint directly onto this process's device
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = model_class()
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # Wrap with DDP for consistent behavior with training
    model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None)

    # Evaluation as before
    # ...

    cleanup()
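One practical detail, sketched below: if the checkpoint was saved from a DDP-wrapped model (i.e. from model.state_dict() without unwrapping model.module first), every key carries a "module." prefix that has to be stripped before loading into the plain model.

# Strip the "module." prefix that DDP adds to parameter names when saving
state_dict = checkpoint['model_state_dict']
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)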
3. Zero Redundancy Optimizer Integration
If you're using ZeRO (Zero Redundancy Optimizer) for training, your evaluation entry point can mirror that setup:

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

def zero_evaluate(rank, world_size, model, val_dataset, optimizer_class=torch.optim.Adam):
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # ZeroRedundancyOptimizer shards optimizer state (not the model parameters)
    # across ranks. It isn't strictly needed for evaluation, since no optimizer
    # steps are taken, but recreating it keeps this entry point consistent with
    # the training setup.
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=optimizer_class,
        lr=0.001  # Dummy value, not used during evaluation
    )

    # Create distributed sampler for validation data
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=32, sampler=val_sampler)

    # Evaluation with ZeRO
    model.eval()
    with torch.no_grad():
        # Evaluation logic as before
        # ...
        pass

    cleanup()
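Related to checkpointing from a ZeRO training run (a sketch, not required for evaluation itself): ZeroRedundancyOptimizer keeps only a shard of the optimizer state on each rank, so if you also want to save the full optimizer state alongside the model you first consolidate it onto one rank.

# Gather the sharded optimizer state onto rank 0 before saving a checkpoint.
optimizer.consolidate_state_dict(to=0)
if rank == 0:
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),  # full state, only valid on rank 0
    }, 'checkpoint.pt')
dist.barrier()  # keep other ranks from exiting before rank 0 finishes saving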
Real-World Example: Evaluating BERT on SQuAD Dataset
Let's implement distributed evaluation for a BERT model on the Stanford Question Answering Dataset (SQuAD):
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from transformers import BertForQuestionAnswering, BertTokenizerFast

def evaluate_bert_squad(rank, world_size, checkpoint_path):
    # Initialize distributed environment
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    # Load model and tokenizer (the fast tokenizer is required for offset mappings)
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

    # Load your fine-tuned model checkpoint if available
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    model = model.to(device)
    model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None)

    # Load SQuAD validation dataset
    from datasets import load_dataset
    squad_val = load_dataset('squad', split='validation')

    # Preprocess the dataset
    def preprocess_squad(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=384,
            truncation="only_second",
            stride=128,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        offset_mapping = inputs.pop("offset_mapping")
        sample_map = inputs.pop("overflow_to_sample_mapping")
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            sample_idx = sample_map[i]
            answer = examples["answers"][sample_idx]
            start_char = answer["answer_start"][0]
            end_char = answer["answer_start"][0] + len(answer["text"][0])
            sequence_ids = inputs.sequence_ids(i)

            # Find start and end token positions within the context
            start_token = end_token = None
            for idx, (o, s) in enumerate(zip(offset, sequence_ids)):
                if s != 1:
                    continue
                if start_token is None and o[0] <= start_char < o[1]:
                    start_token = idx
                if end_token is None and o[0] <= end_char <= o[1]:
                    end_token = idx

            if start_token is None:
                start_token = 0
            if end_token is None:
                end_token = 0

            start_positions.append(start_token)
            end_positions.append(end_token)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs

    processed_squad = squad_val.map(
        preprocess_squad,
        batched=True,
        remove_columns=squad_val.column_names,
    )

    # Create the DataLoader
    val_sampler = DistributedSampler(
        processed_squad,
        num_replicas=world_size,
        rank=rank,
        shuffle=False
    )

    val_loader = DataLoader(
        processed_squad,
        batch_size=16,
        sampler=val_sampler,
        collate_fn=lambda batch: {k: torch.stack([torch.tensor(x[k]) for x in batch]) for k in batch[0]},
    )

    # Evaluation loop
    model.eval()
    all_start_logits = []
    all_end_logits = []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits

            all_start_logits.append(start_logits.cpu())
            all_end_logits.append(end_logits.cpu())

    # Concatenate all logits from this process
    all_start_logits = torch.cat(all_start_logits, dim=0)
    all_end_logits = torch.cat(all_end_logits, dim=0)

    # Gather results from all processes. With the NCCL backend the tensors passed
    # to all_gather must be on the GPU, so move them to `device` first.
    all_start_logits = all_start_logits.to(device)
    all_end_logits = all_end_logits.to(device)
    gathered_start_logits = [torch.zeros_like(all_start_logits) for _ in range(world_size)]
    gathered_end_logits = [torch.zeros_like(all_end_logits) for _ in range(world_size)]
    dist.all_gather(gathered_start_logits, all_start_logits)
    dist.all_gather(gathered_end_logits, all_end_logits)

    # Now process 0 computes the final metrics
    if rank == 0:
        full_start_logits = torch.cat(gathered_start_logits).cpu()
        full_end_logits = torch.cat(gathered_end_logits).cpu()

        # Convert to predicted answer spans
        predictions = []
        for i in range(len(full_start_logits)):
            start_idx = torch.argmax(full_start_logits[i])
            end_idx = torch.argmax(full_end_logits[i])
            predictions.append((start_idx.item(), end_idx.item()))

        # Calculate F1 and Exact Match scores
        # Note: This is a simplified version - a real implementation would need
        # to map these predicted spans back to text in the original context
        print("Evaluation complete! Process predictions to get F1 and EM scores.")

    cleanup()
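Once the predicted spans have been mapped back to answer strings, the standard SQuAD metrics can be computed with the Hugging Face evaluate library (assuming it is installed; the id and answer texts below are placeholders, not real dataset entries):

import evaluate

squad_metric = evaluate.load("squad")
results = squad_metric.compute(
    predictions=[{"id": "example-id", "prediction_text": "predicted answer"}],
    references=[{"id": "example-id",
                 "answers": {"text": ["gold answer"], "answer_start": [0]}}],
)
print(results)  # e.g. {'exact_match': ..., 'f1': ...}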
Best Practices for Distributed Evaluation
- Keep determinism:
  - Set seeds in each process to ensure reproducibility
  - Use shuffle=False in the DistributedSampler
- Memory efficiency:
  - Clear GPU memory between batches for large evaluations
  - Use mixed precision when applicable, for example:

    # Example of memory-efficient evaluation
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
        for batch in val_loader:
            outputs = model(batch)
            # Process outputs
    torch.cuda.empty_cache()  # Clear cache if needed

- Synchronize metrics properly:
  - Use all_reduce to aggregate metrics across processes
  - Be careful with reduced metrics - simple averaging isn't always correct
- Handle edge cases:
  - Different processes may have different batch counts if the dataset size isn't divisible by world_size (by default, DistributedSampler pads by repeating samples so each rank gets an equal share)
  - Consider weighted averaging for metrics when different processes see different numbers of samples (see the sketch after this list)
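As referenced in the last item, here is a sketch of two ways to handle the samples DistributedSampler duplicates to even out the shards (exact de-duplication would require tracking sample indices):

# Option 1: drop the uneven tail instead of padding it (loses a few samples)
val_sampler = DistributedSampler(val_dataset, num_replicas=world_size,
                                 rank=rank, shuffle=False, drop_last=True)

# Option 2: keep the padding, but remember that the aggregated sample count can
# exceed len(val_dataset), so metrics are slightly approximate on the tail.
effective_total = min(total_tensor.item(), len(val_dataset))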
Summary
Distributed evaluation is a crucial part of the deep learning workflow, allowing efficient validation of models trained in distributed environments. In this guide, we've covered:
- Basic setup for distributed evaluation
- Implementing proper metric aggregation across processes
- Advanced techniques for handling custom metrics and large models
- Real-world examples with complex models like BERT
- Best practices to ensure correct and efficient evaluation
By applying these techniques, you can significantly speed up your model validation process while maintaining accuracy in your metrics.
Additional Resources
- PyTorch Distributed Documentation
- PyTorch Performance Tuning Guide
- NVIDIA Developer Blog: Scaling PyTorch
Exercises
- Implement distributed evaluation for an image classification model using ResNet on ImageNet.
- Modify the BERT evaluation example to compute and report F1 and Exact Match scores correctly.
- Implement distributed evaluation with early stopping based on validation metrics.
- Create a function that can evaluate a model on multiple GPUs on a single node without using mp.spawn().
- Extend the distributed evaluation framework to log metrics to a monitoring system like TensorBoard or Weights & Biases.