PyTorch Model Debugging
Deep learning models can be notoriously difficult to debug. When your PyTorch model isn't performing as expected, having a systematic debugging approach can save you countless hours of frustration. This guide will walk you through various techniques to effectively debug your PyTorch models.
Introduction
Model debugging in PyTorch involves identifying and fixing issues that prevent your neural networks from learning properly or producing the expected outputs. Common problems include:
- Models that don't learn (loss doesn't decrease)
- Unexpected outputs or predictions
- Runtime errors during training or inference
- Memory issues with large models
- Numerical instabilities (NaN/Infinity values)
In this guide, we'll explore a structured approach to debugging these problems and the tools PyTorch provides to help along the way.
Basic Model Debugging Techniques
1. Verify Your Input Data
One of the first things to check is whether your input data is correct:
import matplotlib.pyplot as plt
import torch

def inspect_data(dataloader):
    # Pull a single batch; assumes the loader yields (inputs, labels) pairs
    inputs, labels = next(iter(dataloader))
    print(f"Input shape: {inputs.shape}")
    print(f"Input dtype: {inputs.dtype}")
    print(f"Input range: [{inputs.min().item()}, {inputs.max().item()}]")
    print(f"Labels shape: {labels.shape}")
    print(f"Labels unique values: {torch.unique(labels)}")

    # Visualize a sample if it's an image batch
    if inputs.dim() == 4:  # BCHW format
        # imshow expects HWC; float values outside [0, 1] are clipped for display
        plt.imshow(inputs[0].permute(1, 2, 0).cpu().numpy())
        plt.title(f"Label: {labels[0].item()}")
        plt.show()
Example Output:
Input shape: torch.Size([32, 3, 224, 224])
Input dtype: torch.float32
Input range: [0.0, 1.0]
Labels shape: torch.Size([32])
Labels unique values: tensor([0, 1, 2, 3, 4])
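Printing shapes and ranges catches many problems, but a few checks are easy to automate. The following is a minimal sketch, not part of the original guide: `check_batch` and its `num_classes` argument are hypothetical names, and the batch is synthetic so the snippet runs without a dataset.

```python
import torch

def check_batch(inputs, labels, num_classes):
    """Return a list of human-readable problems found in one batch."""
    problems = []
    # NaN or Inf in the inputs usually points at a broken preprocessing step
    if not torch.isfinite(inputs).all():
        problems.append("inputs contain NaN or Inf")
    if inputs.shape[0] != labels.shape[0]:
        problems.append("inputs and labels have different batch sizes")
    # Class-index labels must lie in [0, num_classes) for CrossEntropyLoss
    if labels.min() < 0 or labels.max() >= num_classes:
        problems.append("labels fall outside [0, num_classes)")
    return problems

# Synthetic batch standing in for one batch from a real DataLoader
inputs = torch.rand(8, 3, 32, 32)
labels = torch.randint(0, 5, (8,))
print(check_batch(inputs, labels, num_classes=5))

# Corrupt one pixel to show what a failing check looks like
inputs[0, 0, 0, 0] = float("nan")
print(check_batch(inputs, labels, num_classes=5))
```

Running a check like this on the first few batches is cheap, and it fails with a clearer message than an error raised deep inside the loss function.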
2. Check Model Architecture
Ensure your model architecture is correctly defined:
def inspect_model(model, input_shape=(1, 3, 224, 224)):
    # Print model architecture
    print(model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Verify the forward pass with a random input.
    # Set input_shape to match your data (for example, (1, 3, 32, 32) for CIFAR-10).
    device = next(model.parameters()).device
    random_input = torch.randn(*input_shape, device=device)
    try:
        with torch.no_grad():
            output = model(random_input)
        print(f"Forward pass output shape: {output.shape}")
        print(f"Output range: [{output.min().item():.4f}, {output.max().item():.4f}]")
    except Exception as e:
        print(f"Forward pass failed with error: {e}")
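When a forward pass fails or returns an unexpected shape, it helps to see the shape after every layer. Here is one possible sketch using forward hooks. `trace_shapes` is a hypothetical helper, and the small `nn.Sequential` model exists only for illustration.

```python
import torch
import torch.nn as nn

def trace_shapes(model, example_input):
    """Run one forward pass and record the output shape of every leaf module."""
    shapes = []
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            # name=name binds the current name to this particular hook
            def hook(mod, inp, out, name=name):
                if isinstance(out, torch.Tensor):
                    shapes.append((name, tuple(out.shape)))
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        # Always remove the hooks so later forward passes are unaffected
        for handle in handles:
            handle.remove()
    return shapes

# Toy model used only to demonstrate the helper
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 16 * 16, 10),
)
for name, shape in trace_shapes(toy, torch.randn(1, 3, 32, 32)):
    print(f"{name}: {shape}")
```

The hooks are removed in a `finally` block, so a failing forward pass does not leave stale hooks attached to the model.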
3. Monitor Gradients
Tracking gradients can help identify issues like vanishing or exploding gradients:
def register_gradient_hooks(model):
    gradient_stats = {}

    def hook_factory(name):
        # The factory binds `name` per parameter; a closure created directly
        # in the loop would only ever see the last name
        def hook(grad):
            stats = gradient_stats.setdefault(
                name, {"min": [], "max": [], "mean": [], "norm": []})
            stats["min"].append(grad.min().item())
            stats["max"].append(grad.max().item())
            stats["mean"].append(grad.mean().item())
            stats["norm"].append(grad.norm().item())
        return hook

    for name, param in model.named_parameters():
        if param.requires_grad:
            param.register_hook(hook_factory(name))
    return gradient_stats

# Usage during training
model = YourModel()  # placeholder for your own nn.Module
gradient_stats = register_gradient_hooks(model)

# After some training iterations
def plot_gradient_stats(gradient_stats, layer_name):
    stats = gradient_stats[layer_name]
    plt.figure(figsize=(12, 8))
    # One subplot per recorded statistic
    for i, key in enumerate(["min", "max", "mean", "norm"], start=1):
        plt.subplot(2, 2, i)
        plt.plot(stats[key])
        plt.title(f"{layer_name} - Gradient {key.capitalize()}")
    plt.tight_layout()
    plt.show()
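Plots are useful over many iterations, but a quick text report after a single backward pass is often enough to spot a layer whose gradient is missing, tiny, or huge. The sketch below assumes illustrative thresholds (`low`, `high`) that should be tuned per model. `grad_norm_report` is a hypothetical helper, not a PyTorch API.

```python
import math
import torch
import torch.nn as nn

def grad_norm_report(model, low=1e-7, high=1e3):
    """Return {param_name: (grad_norm, status)} after a backward pass."""
    report = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            # Frozen parameter, or one that never took part in the loss
            report[name] = (None, "no gradient")
            continue
        norm = param.grad.norm().item()
        if not math.isfinite(norm):
            status = "non-finite"
        elif norm < low:
            status = "possibly vanishing"
        elif norm > high:
            status = "possibly exploding"
        else:
            status = "ok"
        report[name] = (norm, status)
    return report

# Tiny example: one linear layer and a dummy loss
layer = nn.Linear(4, 2)
loss = layer(torch.ones(3, 4)).sum()
loss.backward()
for name, (norm, status) in grad_norm_report(layer).items():
    print(f"{name}: norm={norm:.4f} ({status})")
```

A parameter reported as "no gradient" is often the most informative result: it usually means part of the model is disconnected from the loss.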
Using PyTorch's Built-in Debugging Tools
1. Autograd Profiler
The autograd profiler helps you understand the time and memory consumption of your operations:
import torch
import torch.autograd.profiler as profiler

model = YourModel()  # placeholder for your own nn.Module
inputs = torch.randn(32, 3, 224, 224)  # named `inputs` to avoid shadowing the builtin `input`

with profiler.profile(record_shapes=True) as prof:
    with profiler.record_function("model_inference"):
        output = model(inputs)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
Example Output:
---------------------------  ----------  ----------  -----------  ----------  ------------  ----------
Name                         Self CPU %  Self CPU    CPU total %  CPU total   CPU time avg  # of Calls
---------------------------  ----------  ----------  -----------  ----------  ------------  ----------
model_inference              0.12%       235.000us   100.00%      197.800ms   197.800ms     1
aten::conv2d                 0.02%       30.000us    93.72%       185.438ms   23.180ms      8
aten::convolution_forward    0.01%       20.000us    93.70%       185.408ms   23.176ms      8
aten::_convolution           0.02%       38.000us    93.69%       185.388ms   23.173ms      8
...
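Recent PyTorch releases also provide the `torch.profiler` module, which offers a similar workflow. The following sketch profiles a toy model on the CPU only. The model and input sizes are made up for illustration, and the timings will differ on every machine.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function

# Toy model and input used only for demonstration
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)
inputs = torch.randn(4, 3, 32, 32)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    # record_function labels this region so it shows up as a row in the table
    with record_function("model_inference"):
        with torch.no_grad():
            output = model(inputs)

table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(table)
```

To profile GPU work as well, `ProfilerActivity.CUDA` can be added to `activities` when a CUDA device is available.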
2. Anomaly Detection
PyTorch has a built-in anomaly detection feature that can help catch issues like NaN gradients early. It adds noticeable overhead to every backward pass, so enable it only while debugging:
# Enable anomaly detection
torch.autograd.set_detect_anomaly(True)

def train_with_anomaly_detection(model, dataloader, criterion, optimizer, epochs=1):
    for epoch in range(epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            try:
                loss.backward()
                optimizer.step()
            except RuntimeError as e:
                print(f"Caught backward error: {e}")
                # Additional debugging information
                print(f"Loss value: {loss.item()}")
                # Check parameters and gradients for NaN values
                for name, param in model.named_parameters():
                    if torch.isnan(param).any():
                        print(f"NaN detected in parameters: {name}")
                    if param.grad is not None and torch.isnan(param.grad).any():
                        print(f"NaN detected in gradients: {name}")
                # Stop training entirely; a bare break would only leave the inner loop
                return
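To see what anomaly detection reports, it helps to trigger a NaN gradient on purpose. In this small sketch, the gradient of `sqrt` at zero is infinite and is multiplied by an upstream gradient of zero, which produces NaN during the backward pass. The function name is hypothetical and the setup is deliberately artificial.

```python
import torch

def backward_with_anomaly_check():
    """Return the error message raised by anomaly detection, or None."""
    x = torch.tensor([0.0], requires_grad=True)
    # The context manager enables anomaly detection only inside this block
    with torch.autograd.detect_anomaly():
        # The forward value is a harmless 0.0, but the backward pass computes 0 / 0
        y = (torch.sqrt(x) * 0.0).sum()
        try:
            y.backward()
        except RuntimeError as err:
            return str(err)
    return None

message = backward_with_anomaly_check()
print(message)
```

Without anomaly detection, the same backward pass finishes silently and leaves NaN in `x.grad`. That is why the failure often surfaces many iterations after the operation that caused it.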
Advanced Debugging Techniques
1. Visualizing Activations
Visualize intermediate activations to understand what your network is learning:
import matplotlib.pyplot as plt
import torch
import torchvision

class FeatureExtractor:
    def __init__(self, model, layers):
        self.model = model
        self.layers = layers
        self.features = {layer: None for layer in layers}
        modules = dict(self.model.named_modules())
        # Register a forward hook on each requested layer
        for layer_name in self.layers:
            modules[layer_name].register_forward_hook(self.get_hook(layer_name))

    def get_hook(self, layer_name):
        def hook(module, input, output):
            # Detach so stored activations do not keep the autograd graph alive
            self.features[layer_name] = output.detach()
        return hook

    def extract_features(self, x):
        _ = self.model(x)
        return self.features

# Usage example
# torchvision >= 0.13 selects weights via `weights=`; older versions use pretrained=True
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
model.eval()  # use running batch-norm statistics for a single image
extractor = FeatureExtractor(model, ['layer1', 'layer2', 'layer3'])

# Get a sample image (random noise here; substitute a real preprocessed image)
input_image = torch.randn(1, 3, 224, 224)
features = extractor.extract_features(input_image)

# Visualize activations
def visualize_feature_maps(features, layer_name, num_filters=8):
    feature_map = features[layer_name][0].cpu().numpy()  # shape: (channels, H, W)
    n = min(num_filters, feature_map.shape[0])
    # squeeze=False keeps `axes` two-dimensional even when n == 1
    fig, axes = plt.subplots(1, n, figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(feature_map[i], cmap='viridis')
        ax.set_title(f'Filter {i}')
        ax.axis('off')
    plt.suptitle(f'Feature Maps for {layer_name}')
    plt.show()

visualize_feature_maps(features, 'layer1', num_filters=8)
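Feature maps can also be summarized numerically. One common check is the fraction of activations that are exactly zero after each ReLU, since a value near 1.0 suggests "dead" units that no longer respond to any input. The sketch below is one way to compute this with forward hooks. `zero_fraction_per_relu` is a hypothetical helper, and the tiny network is rigged so that every unit is dead.

```python
import torch
import torch.nn as nn

def zero_fraction_per_relu(model, example_input):
    """Map each nn.ReLU module name to the fraction of its outputs equal to zero."""
    fractions = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            # name=name binds the current name to this particular hook
            def hook(mod, inp, out, name=name):
                fractions[name] = (out == 0).float().mean().item()
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:
            handle.remove()
    return fractions

# Rigged example: zero weights and a negative bias make every ReLU output zero
net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
with torch.no_grad():
    net[0].weight.zero_()
    net[0].bias.fill_(-1.0)
print(zero_fraction_per_relu(net, torch.randn(2, 4)))
```

This only detects `nn.ReLU` modules. Models that call `torch.relu` or `F.relu` directly inside `forward` would need hooks on the preceding layers instead.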
2. Weight and Bias Distribution Analysis
Monitor the distribution of weights and biases to catch issues such as poorly scaled initialization, weights that drift to very large values, or parameters that stop changing between epochs:
def analyze_parameters(model, epoch=0):
    plt.figure(figsize=(20, 10))
    plot_index = 1
    for name, param in model.named_parameters():
        if 'weight' in name or 'bias' in name:
            plt.subplot(3, 4, plot_index)
            # Convert to numpy and flatten
            values = param.detach().cpu().numpy().flatten()
            # Plot histogram
            plt.hist(values, bins=50)
            plt.title(f"{name} - Epoch {epoch}")
            plt.xlabel("Value")
            plt.ylabel("Count")
            # Calculate and display statistics
            mean = values.mean()
            std = values.std()
            plt.axvline(mean, color='r', linestyle='dashed', linewidth=1)
            plt.text(0.05, 0.95, f"μ={mean:.4f}, σ={std:.4f}",
                     transform=plt.gca().transAxes, fontsize=10,
                     verticalalignment='top')
            plot_index += 1
            if plot_index > 12:  # Limit to 12 subplots
                break
    plt.tight_layout()
    plt.show()
Practical Example: Debugging a CNN Model
Let's put everything together in a real-world example. We'll create a simple CNN for image classification and debug common issues.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Define a simple CNN with a deliberate bug
class BuggyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)  # Expects flattened 64x8x8 features
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Bug: missing flatten operation, so x is still (B, 64, 8, 8)
        x = self.fc1(x)  # This will cause a shape error
        x = torch.relu(x)
        x = self.fc2(x)
        return x
# Set up a simple dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)
# Initialize model and optimizer
model = BuggyModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Debug step 1: Inspect data
inspect_data(trainloader)
# Debug step 2: Check model architecture
try:
    inspect_model(model)
except Exception as e:
    print(f"Model inspection failed: {e}")
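# Fix the model
class FixedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        # CIFAR-10 images are 32x32; two 2x2 pooling layers reduce them to 8x8,
        # so the flattened feature size is 64 * 8 * 8
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Add the flatten operation; flatten(x, 1) keeps the batch dimension,
        # so an input of the wrong size raises an error instead of silently
        # changing the batch size the way view(-1, ...) would
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = FixedModel()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
inspect_model(model)
# inspect_model's forward check uses a 224x224 probe by default, which does not
# match CIFAR-10's 32x32 images, so also verify the forward pass directly with
# a correctly sized input
print(model(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 10])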
# Debug step 3: Track gradients
gradient_stats = register_gradient_hooks(model)
# Debug step 4: Train with anomaly detection
torch.autograd.set_detect_anomaly(True)
# Modified training loop with debugging
import time

def debug_training(model, trainloader, criterion, optimizer, num_batches=10):
    model.train()
    device = next(model.parameters()).device
    use_cuda = device.type == "cuda"
    running_loss = 0.0
    batches_seen = 0
    for i, (inputs, labels) in enumerate(trainloader):
        if i >= num_batches:
            break
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # Forward pass with timing. CUDA kernels run asynchronously, so
        # synchronize before reading the clock when the model is on a GPU;
        # on CPU a wall-clock timer is enough (torch.cuda.Event requires CUDA)
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        outputs = model(inputs)
        if use_cuda:
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"Batch {i} - Forward pass time: {elapsed_ms:.2f} ms")

        # Check for NaN in outputs
        if torch.isnan(outputs).any():
            print(f"NaN detected in outputs at batch {i}")
            break

        loss = criterion(outputs, labels)
        # Check if loss is valid
        if torch.isnan(loss).any():
            print(f"NaN detected in loss at batch {i}")
            break
        print(f"Batch {i} - Loss: {loss.item():.4f}")

        # Backward pass
        loss.backward()
        # Check for NaN in gradients
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"NaN detected in gradient for {name} at batch {i}")
                return

        optimizer.step()
        running_loss += loss.item()
        batches_seen += 1

    # Average over the batches that actually completed, not the requested count
    if batches_seen > 0:
        print(f"Average loss: {running_loss / batches_seen:.4f}")
# Run the debug training
debug_training(model, trainloader, criterion, optimizer)
# Visualize gradient statistics after training
plot_gradient_stats(gradient_stats, 'conv1.weight')
# Analyze weight distributions
analyze_parameters(model, epoch=0)
Common Issues and Solutions
Problem | Symptoms | Possible Solutions |
---|---|---|
Vanishing Gradients | Loss plateaus, weights in early layers don't update | Use skip connections, batch normalization, proper weight initialization |
Exploding Gradients | NaN values in loss or weights, model diverges | Gradient clipping, reduce learning rate, weight decay |
Overfitting | Training loss keeps decreasing, validation loss increases | Data augmentation, dropout, regularization, early stopping |
Underfitting | Both training and validation loss stay high | Increase model capacity, train longer, reduce regularization |
Memory Issues | CUDA out of memory errors | Reduce batch size, use gradient accumulation, model parallelism |
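As an illustration of one remedy from the table, gradient clipping is applied between `loss.backward()` and `optimizer.step()`. The sketch below uses `torch.nn.utils.clip_grad_norm_` on a toy regression problem whose inputs are scaled up on purpose to inflate the gradients. The `max_norm` value of 1.0 is an arbitrary choice for demonstration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Deliberately large inputs so the gradient norm is far above max_norm
inputs = torch.randn(16, 10) * 1e3
targets = torch.randn(16, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(inputs), targets)
loss.backward()

max_norm = 1.0  # illustrative threshold, tune per model
# clip_grad_norm_ rescales gradients in place and returns the norm before clipping
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
norm_after = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters()))
optimizer.step()

print(f"gradient norm before clipping: {norm_before.item():.2f}")
print(f"gradient norm after clipping:  {norm_after.item():.4f}")
```

Logging the value returned by `clip_grad_norm_` during training is a cheap way to see how often clipping is active. If it triggers on nearly every step, the learning rate may be too high.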
Systematic Debugging Approach
When debugging PyTorch models, follow these steps:
- Start small: Use a small subset of data and a simple model
- Verify data: Check shapes, ranges, and correctness of your data pipeline
- Validate model: Ensure forward/backward passes work with random data
- Monitor metrics: Track loss, accuracy, gradients during training
- Visualize: Look at activations, weight distributions, and predictions
- Log everything: Save checkpoints and logs for reference
- Isolate issues: Test components individually to find the problem
- Simplify: Remove parts of the model until the issue disappears
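The "start small" step is often implemented as an overfitting test: a model that cannot drive the loss toward zero on one tiny batch almost certainly has a bug in the model, the loss, or the optimizer setup. The following sketch assumes a classification task. `overfit_single_batch` is a hypothetical helper, and the data is random.

```python
import torch
import torch.nn as nn

def overfit_single_batch(model, inputs, labels, steps=200, lr=0.01):
    """Train repeatedly on one batch and return the loss at every step."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

torch.manual_seed(0)
# Toy classifier and a random batch; even random labels can be memorized
toy_model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 3))
inputs = torch.randn(16, 20)
labels = torch.randint(0, 3, (16,))

losses = overfit_single_batch(toy_model, inputs, labels)
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

If the loss refuses to fall in this setting, check the basics before scaling up: whether the optimizer received the right parameters, whether gradients are zeroed on each step, and whether the labels match the loss function's expectations.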
Summary
Debugging PyTorch models requires a systematic approach and the right tools. In this guide, we've covered:
- Basic debugging techniques for data and model verification
- How to monitor and visualize gradients, activations, and weights
- Using PyTorch's built-in debugging tools like autograd profiler and anomaly detection
- A practical example of identifying and fixing common issues
- A systematic approach to model debugging
By applying these techniques, you can save valuable time and ensure your models are working correctly.
Additional Resources
- PyTorch Documentation on Autograd
- PyTorch Forums for specific issues
- PyTorch Profiler for performance debugging
- TensorBoard integration with PyTorch for visualization
Exercises
- Debug a model that's suffering from vanishing gradients using the techniques described above
- Implement a custom hook to track the mean and standard deviation of activations in each layer
- Create a function that automatically checks for common issues in a PyTorch model
- Use PyTorch's profiler to identify the bottlenecks in a complex model and optimize them
- Build a dashboard using TensorBoard that monitors all the key debugging metrics in real-time