PyTorch Distributed Evaluation

Introduction

When training models at scale with PyTorch's distributed capabilities, it's equally important to perform evaluation efficiently. Distributed evaluation allows you to validate your model's performance across multiple GPUs and nodes, significantly reducing the time needed to evaluate large models on substantial validation datasets.

In this guide, we'll explore how to implement distributed evaluation strategies that complement your distributed training workflows. You'll learn how to properly evaluate models across multiple devices while ensuring accurate aggregation of metrics.

Why Distributed Evaluation Matters

Evaluation can become a bottleneck when working with:

  • Large models that require significant memory
  • Extensive validation datasets
  • Time-sensitive applications where quick feedback is necessary

By distributing the evaluation workload, you can:

  • Process more samples in parallel
  • Reduce evaluation time significantly
  • Use the same distributed infrastructure you already set up for training

Basic Concepts of Distributed Evaluation

Distributed evaluation follows many of the same principles as distributed training, but with some key differences:

  1. No backward pass or gradient computation is needed
  2. Metrics need to be properly aggregated across all processes (see the sketch after this list)
  3. For certain metrics (like accuracy), special care is needed for correct calculation
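
To make points 2 and 3 concrete, here is a minimal sketch of the aggregation pattern the rest of this guide builds on: compute counts locally, then combine them with a collective. The helper below is illustrative and assumes the process group has already been initialized.

python
import torch
import torch.distributed as dist

def aggregate_accuracy(local_correct: int, local_total: int, device: torch.device) -> float:
    """Combine per-rank counts into a global accuracy (process group must already be initialized)."""
    stats = torch.tensor([local_correct, local_total], dtype=torch.float64, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # sum the counts from every rank
    return (stats[0] / stats[1]).item()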

Setting Up Distributed Evaluation

Let's start with a basic setup for distributed evaluation:

python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    """Initialize the distributed environment."""
    # The default env:// rendezvous needs a master address and port;
    # these defaults are fine for a single node.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo",
                            rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)  # make this rank's GPU the current device

def cleanup():
    """Clean up the distributed environment."""
    dist.destroy_process_group()

def distributed_evaluate(rank, world_size, model, val_dataset):
    # Initialize the distributed process group
    setup(rank, world_size)

    # Move model to the appropriate device
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Wrap the model with DDP
    model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None)

    # Create distributed sampler for validation data
    val_sampler = DistributedSampler(
        val_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False  # Important for evaluation: maintain order
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        sampler=val_sampler,
        num_workers=2,
        pin_memory=True
    )

    # Now run evaluation
    model.eval()
    with torch.no_grad():
        # Your evaluation logic goes here
        pass

    cleanup()

Implementing the Evaluation Loop

Now let's implement a complete evaluation function with proper metric aggregation:

python
def distributed_evaluate(rank, world_size, model, val_dataset, criterion):
    # Setup and data loading (as above)
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None)

    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=32, sampler=val_sampler, num_workers=2)

    # Evaluation loop
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            # Calculate loss
            loss = criterion(outputs, targets)
            val_loss += loss.item()

            # Calculate accuracy
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Aggregate metrics from all processes
    val_loss_tensor = torch.tensor([val_loss], device=device)
    correct_tensor = torch.tensor([correct], device=device)
    total_tensor = torch.tensor([total], device=device)

    # Sum the metrics across all processes
    dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
    dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
    dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)

    # Calculate the global metrics. Each rank processed len(val_loader) batches,
    # so this is the mean loss per batch across the whole validation set.
    val_loss = val_loss_tensor.item() / (world_size * len(val_loader))
    accuracy = 100.0 * correct_tensor.item() / total_tensor.item()

    # Print results only on the main process
    if rank == 0:
        print(f"Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%")

    cleanup()
    return val_loss, accuracy
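
As a side note, recent PyTorch versions with the NCCL backend also provide dist.ReduceOp.AVG, which averages during the collective instead of requiring a manual division afterwards. Here is a sketch of a drop-in alternative for the loss aggregation above, assuming that op is available in your build:

python
# Alternative to SUM followed by manual division (NCCL backend, recent PyTorch only)
mean_loss = torch.tensor([val_loss / len(val_loader)], device=device)  # per-batch mean on this rank
dist.all_reduce(mean_loss, op=dist.ReduceOp.AVG)  # average across ranks in one collective
val_loss = mean_loss.item()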

Running the Distributed Evaluation

To run our distributed evaluation, we need to launch it using PyTorch's multiprocessing:

python
def main():
    # Create model and dataset
    model = YourModel()
    val_dataset = YourDataset()
    criterion = nn.CrossEntropyLoss()

    # Determine world size and run distributed evaluation
    world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1

    if world_size > 1:
        print(f"Using {world_size} GPUs for distributed evaluation")
        # mp.spawn passes the rank as the first argument to distributed_evaluate
        mp.spawn(
            distributed_evaluate,
            args=(world_size, model, val_dataset, criterion),
            nprocs=world_size,
            join=True
        )
    else:
        # Fall back to single-device evaluation
        print("Using single device for evaluation")
        distributed_evaluate(0, 1, model, val_dataset, criterion)

if __name__ == "__main__":
    main()
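
mp.spawn is convenient on a single node, but the same evaluation code can also be launched with torchrun, which starts one process per GPU and supplies the rank and world size through environment variables. The entry point below is a sketch of that variant (the script name evaluate.py in the comment is just an example):

python
# Launched with: torchrun --nproc_per_node=4 evaluate.py
import os
import torch
import torch.distributed as dist

def main():
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun

    # torchrun also sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
    # so init_process_group can read everything from the environment.
    dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # ... build the model and dataset, then run the evaluation logic shown above ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()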

Advanced Techniques for Distributed Evaluation

1. Handling Custom Metrics

For custom or complex metrics like Mean Average Precision (mAP) or F1-score, you'll need to gather the predictions and targets from every process (for example with all_gather) and compute the metric on a single rank:

python
from sklearn.metrics import precision_score, recall_score, f1_score

def distributed_evaluate_with_custom_metrics(rank, world_size, model, val_dataset):
    # Setup, DDP wrapping and data loading as before
    # ...

    all_predictions = []
    all_targets = []

    model.eval()
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            predictions = torch.softmax(outputs, dim=1)

            # Store predictions and targets (kept on the device so NCCL can gather them)
            all_predictions.append(predictions)
            all_targets.append(targets)

    # Concatenate all predictions and targets from this process
    all_predictions = torch.cat(all_predictions)
    all_targets = torch.cat(all_targets)

    # Gather predictions and targets from all processes.
    # DistributedSampler pads every shard to the same length, so the per-rank
    # tensors have identical shapes, as all_gather requires.
    gathered_predictions = [torch.zeros_like(all_predictions) for _ in range(world_size)]
    gathered_targets = [torch.zeros_like(all_targets) for _ in range(world_size)]

    dist.all_gather(gathered_predictions, all_predictions)
    dist.all_gather(gathered_targets, all_targets)

    # Process 0 calculates the metrics
    if rank == 0:
        # Concatenate all gathered tensors and move them to the CPU
        full_predictions = torch.cat(gathered_predictions).cpu()
        full_targets = torch.cat(gathered_targets).cpu()

        # Now calculate advanced metrics
        predicted_labels = full_predictions.argmax(dim=1).numpy()
        precision = precision_score(full_targets.numpy(), predicted_labels, average='macro')
        recall = recall_score(full_targets.numpy(), predicted_labels, average='macro')
        f1 = f1_score(full_targets.numpy(), predicted_labels, average='macro')

        print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")

    cleanup()
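
Note that all_gather requires the tensors on every rank to have the same shape and, with the NCCL backend, to live on the GPU. If your per-rank results can differ in length (for example with a sampler that does not pad its shards), dist.all_gather_object is a slower but simpler alternative that works on arbitrary picklable objects. A sketch of replacing the two all_gather calls above under that assumption:

python
# Gather variable-length, CPU-side results from every rank (slower: uses pickling).
# With the NCCL backend this still requires torch.cuda.set_device(rank) to have been called.
local_results = {"preds": all_predictions.cpu(), "targets": all_targets.cpu()}
gathered = [None for _ in range(world_size)]
dist.all_gather_object(gathered, local_results)

if rank == 0:
    full_predictions = torch.cat([g["preds"] for g in gathered])
    full_targets = torch.cat([g["targets"] for g in gathered])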

2. Efficient Evaluation with Checkpointing

When evaluating large models, passing a fully constructed model object through mp.spawn means pickling it into every worker process. Loading the weights from a checkpoint inside each process is often the more practical approach:

python
def evaluate_from_checkpoint(rank, world_size, checkpoint_path, model_class, val_dataset):
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    # Load model from checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = model_class()
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # Wrap with DDP for consistent behavior with training
    model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None)

    # Evaluation as before
    # ...

    cleanup()
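
This assumes the checkpoint stores a plain (non-DDP) state_dict under a 'model_state_dict' key. If the checkpoint was saved from a DDP-wrapped model, every key carries a module. prefix that has to be stripped before loading into the bare model. A sketch of both sides, with ddp_model and checkpoint.pt as placeholder names:

python
# Saving during training (typically on rank 0 only): unwrap the DDP module
# so the saved keys have no "module." prefix.
torch.save({'model_state_dict': ddp_model.module.state_dict()}, 'checkpoint.pt')

# Loading a checkpoint that was saved from the wrapped model anyway:
checkpoint = torch.load('checkpoint.pt', map_location=device)
state_dict = {k.removeprefix('module.'): v  # str.removeprefix needs Python 3.9+
              for k, v in checkpoint['model_state_dict'].items()}
model.load_state_dict(state_dict)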

3. Zero Redundancy Optimizer Integration

If you're using ZeRO (Zero Redundancy Optimizer) for training, you can integrate it with evaluation:

python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

def zero_evaluate(rank, world_size, model, val_dataset, optimizer_class=torch.optim.Adam):
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    model = model.to(device)

    # ZeroRedundancyOptimizer shards the optimizer state across ranks;
    # constructing it here keeps the evaluation setup consistent with training.
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=optimizer_class,
        lr=0.001  # Dummy value, not used for evaluation
    )

    # Create distributed sampler for validation data
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=32, sampler=val_sampler)

    # Evaluation with ZeRO
    model.eval()
    with torch.no_grad():
        # Evaluation logic as before
        # ...
        pass

    cleanup()

Real-World Example: Evaluating BERT on SQuAD Dataset

Let's implement distributed evaluation for a BERT model on the Stanford Question Answering Dataset (SQuAD):

python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from transformers import BertForQuestionAnswering, BertTokenizerFast

def evaluate_bert_squad(rank, world_size, checkpoint_path):
    # Initialize distributed environment
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    # Load model and tokenizer (the fast tokenizer is required for offset mappings)
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

    # Load your fine-tuned model checkpoint if available
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    model = model.to(device)
    model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None)

    # Load SQuAD validation dataset
    from datasets import load_dataset
    squad_val = load_dataset('squad', split='validation')

    # Preprocess the dataset
    def preprocess_squad(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=384,
            truncation="only_second",
            stride=128,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        offset_mapping = inputs.pop("offset_mapping")
        sample_map = inputs.pop("overflow_to_sample_mapping")
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            sample_idx = sample_map[i]
            answer = examples["answers"][sample_idx]
            start_char = answer["answer_start"][0]
            end_char = answer["answer_start"][0] + len(answer["text"][0])
            sequence_ids = inputs.sequence_ids(i)

            # Find start and end token positions within the context (sequence id 1)
            start_token = end_token = None
            for idx, (o, s) in enumerate(zip(offset, sequence_ids)):
                if s != 1:
                    continue
                if start_token is None and o[0] <= start_char < o[1]:
                    start_token = idx
                if end_token is None and o[0] <= end_char <= o[1]:
                    end_token = idx

            if start_token is None:
                start_token = 0
            if end_token is None:
                end_token = 0

            start_positions.append(start_token)
            end_positions.append(end_token)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs

    processed_squad = squad_val.map(
        preprocess_squad,
        batched=True,
        remove_columns=squad_val.column_names,
    )

    # Create the DataLoader
    val_sampler = DistributedSampler(
        processed_squad,
        num_replicas=world_size,
        rank=rank,
        shuffle=False
    )

    val_loader = DataLoader(
        processed_squad,
        batch_size=16,
        sampler=val_sampler,
        collate_fn=lambda batch: {k: torch.stack([torch.tensor(x[k]) for x in batch]) for k in batch[0]},
    )

    # Evaluation loop
    model.eval()
    all_start_logits = []
    all_end_logits = []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Keep the logits on the device so NCCL can gather them below
            all_start_logits.append(outputs.start_logits)
            all_end_logits.append(outputs.end_logits)

    # Concatenate all logits
    all_start_logits = torch.cat(all_start_logits, dim=0)
    all_end_logits = torch.cat(all_end_logits, dim=0)

    # Gather results from all processes (shards have equal length thanks to the sampler's padding)
    gathered_start_logits = [torch.zeros_like(all_start_logits) for _ in range(world_size)]
    gathered_end_logits = [torch.zeros_like(all_end_logits) for _ in range(world_size)]

    dist.all_gather(gathered_start_logits, all_start_logits)
    dist.all_gather(gathered_end_logits, all_end_logits)

    # Now process 0 computes the final metrics
    if rank == 0:
        full_start_logits = torch.cat(gathered_start_logits).cpu()
        full_end_logits = torch.cat(gathered_end_logits).cpu()

        # Convert to predictions
        predictions = []
        for i in range(len(full_start_logits)):
            start_idx = torch.argmax(full_start_logits[i])
            end_idx = torch.argmax(full_end_logits[i])
            predictions.append((start_idx.item(), end_idx.item()))

        # Calculate F1 and Exact Match scores
        # Note: This is a simplified version - a real implementation would need
        # to map these predictions back to the original context
        print("Evaluation complete! Process predictions to get F1 and EM scores.")

    cleanup()
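
For reference, once the predicted spans have been mapped back to answer strings, Exact Match and F1 are computed on normalized text. The helpers below are a minimal sketch of the standard SQuAD-style string metrics; mapping logits back to answer strings is left out:

python
import collections
import re
import string

def normalize_answer(s):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match_score(prediction, ground_truth):
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))

def f1_score_squad(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = collections.Counter(pred_tokens) & collections.Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

# Example: exact_match_score("the Eiffel Tower", "Eiffel Tower") -> 1.0 after normalization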

Best Practices for Distributed Evaluation

  1. Keep Determinism:

    • Set seeds in each process to ensure reproducibility
    • Use shuffle=False in the DistributedSampler
  2. Memory Efficiency:

    • Clear GPU memory between batches for large evaluations
    • Use mixed precision when applicable
    python
    # Example of memory-efficient evaluation ('use_amp' is a flag you define)
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
        for batch in val_loader:
            outputs = model(batch)
            # Process outputs
            torch.cuda.empty_cache()  # Clear cache between batches if needed
  3. Synchronize Metrics Properly:

    • Use all_reduce to aggregate metrics across processes
    • Be careful with reduced metrics: averaging per-process ratios (such as accuracy) is not the same as dividing globally summed counts
  4. Handle Edge Cases:

    • By default, DistributedSampler pads each shard by repeating samples so every process gets the same count, which can slightly skew metrics; with drop_last=True on the sampler (or custom sharding), processes can end up with different sample counts
    • Consider weighted averaging for metrics when processes evaluate different numbers of samples (see the sketch after this list)
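
Here is the sketch referenced in point 4: instead of averaging per-process means directly, weight each process's mean by the number of samples it actually evaluated (local_mean and local_count are placeholders for values you compute in your loop).

python
import torch
import torch.distributed as dist

def weighted_mean(local_mean: float, local_count: int, device: torch.device) -> float:
    """Average a per-rank mean metric, weighting each rank by its sample count."""
    stats = torch.tensor([local_mean * local_count, local_count], dtype=torch.float64, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # sum weighted values and counts across ranks
    return (stats[0] / stats[1]).item()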

Summary

Distributed evaluation is a crucial part of the deep learning workflow, allowing efficient validation of models trained in distributed environments. In this guide, we've covered:

  • Basic setup for distributed evaluation
  • Implementing proper metric aggregation across processes
  • Advanced techniques for handling custom metrics and large models
  • Real-world examples with complex models like BERT
  • Best practices to ensure correct and efficient evaluation

By applying these techniques, you can significantly speed up your model validation process while maintaining accuracy in your metrics.

Exercises

  1. Implement distributed evaluation for an image classification model using ResNet on ImageNet.
  2. Modify the BERT evaluation example to compute and report F1 and Exact Match scores correctly.
  3. Implement distributed evaluation with early stopping based on validation metrics.
  4. Create a function that can evaluate a model on multiple GPUs but on a single node without using mp.spawn().
  5. Extend the distributed evaluation framework to log metrics to a monitoring system like TensorBoard or Weights & Biases.

