
PyTorch Model Debugging

Deep learning models can be notoriously difficult to debug. When your PyTorch model isn't performing as expected, having a systematic debugging approach can save you countless hours of frustration. This guide will walk you through various techniques to effectively debug your PyTorch models.

Introduction

Model debugging in PyTorch involves identifying and fixing issues that prevent your neural networks from learning properly or producing the expected outputs. Common problems include:

  • Models that don't learn (loss doesn't decrease)
  • Unexpected outputs or predictions
  • Runtime errors during training or inference
  • Memory issues with large models
  • Numerical instabilities (NaN/Infinity values)

In this guide, we'll explore a structured approach to debugging these problems and the tools PyTorch provides to help along the way.

Basic Model Debugging Techniques

1. Verify Your Input Data

One of the first things to check is whether your input data is correct:

python
import torch
import matplotlib.pyplot as plt

def inspect_data(dataloader):
    batch = next(iter(dataloader))
    inputs, labels = batch

    print(f"Input shape: {inputs.shape}")
    print(f"Input dtype: {inputs.dtype}")
    print(f"Input range: [{inputs.min().item()}, {inputs.max().item()}]")
    print(f"Labels shape: {labels.shape}")
    print(f"Labels unique values: {torch.unique(labels)}")

    # Visualize a sample if it's an image
    if len(inputs.shape) == 4:  # BCHW format
        plt.imshow(inputs[0].permute(1, 2, 0).cpu().numpy())
        plt.title(f"Label: {labels[0].item()}")
        plt.show()

Example Output:

Input shape: torch.Size([32, 3, 224, 224])
Input dtype: torch.float32
Input range: [0.0, 1.0]
Labels shape: torch.Size([32])
Labels unique values: tensor([0, 1, 2, 3, 4])
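
Beyond eyeballing shapes and ranges, it helps to assert a few invariants up front, since a single NaN input or an out-of-range label can silently derail training. A minimal sketch (the num_classes default is an assumption for this example):

python
def sanity_check_batch(inputs, labels, num_classes=10):
    # Catch non-finite values in inputs and labels outside the expected range
    assert torch.isfinite(inputs).all(), "Non-finite values found in inputs"
    assert labels.min() >= 0 and labels.max() < num_classes, \
        f"Labels fall outside [0, {num_classes - 1}]"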

2. Check Model Architecture

Ensure your model architecture is correctly defined:

python
def inspect_model(model, input_shape=(1, 3, 224, 224)):
    # Print model architecture
    print(model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Verify forward pass with random input of the expected shape
    device = next(model.parameters()).device
    random_input = torch.randn(*input_shape).to(device)

    try:
        output = model(random_input)
        print(f"Forward pass output shape: {output.shape}")
        print(f"Output range: [{output.min().item():.4f}, {output.max().item():.4f}]")
    except Exception as e:
        print(f"Forward pass failed with error: {str(e)}")

3. Monitor Gradients

Tracking gradients can help identify issues like vanishing or exploding gradients:

python
def register_gradient_hooks(model):
    gradient_stats = {}

    for name, param in model.named_parameters():
        if param.requires_grad:
            def hook_factory(name):
                def hook(grad):
                    if name not in gradient_stats:
                        gradient_stats[name] = {"min": [], "max": [], "mean": [], "norm": []}

                    gradient_stats[name]["min"].append(grad.min().item())
                    gradient_stats[name]["max"].append(grad.max().item())
                    gradient_stats[name]["mean"].append(grad.mean().item())
                    gradient_stats[name]["norm"].append(grad.norm().item())
                return hook

            param.register_hook(hook_factory(name))

    return gradient_stats

# Usage during training
model = YourModel()
gradient_stats = register_gradient_hooks(model)

# After some training iterations
def plot_gradient_stats(gradient_stats, layer_name):
    stats = gradient_stats[layer_name]
    plt.figure(figsize=(12, 8))

    plt.subplot(2, 2, 1)
    plt.plot(stats["min"])
    plt.title(f"{layer_name} - Gradient Min")

    plt.subplot(2, 2, 2)
    plt.plot(stats["max"])
    plt.title(f"{layer_name} - Gradient Max")

    plt.subplot(2, 2, 3)
    plt.plot(stats["mean"])
    plt.title(f"{layer_name} - Gradient Mean")

    plt.subplot(2, 2, 4)
    plt.plot(stats["norm"])
    plt.title(f"{layer_name} - Gradient Norm")

    plt.tight_layout()
    plt.show()
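
Hooks are useful for accumulating statistics across iterations, but for a one-off check you can read each parameter's .grad attribute directly after loss.backward(). A minimal sketch:

python
# Run after loss.backward() and before optimizer.step()
def print_gradient_norms(model):
    for name, param in model.named_parameters():
        if param.grad is None:
            print(f"{name}: no gradient (does it participate in the loss?)")
        else:
            print(f"{name}: grad norm = {param.grad.norm().item():.6f}")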

Using PyTorch's Built-in Debugging Tools

1. Autograd Profiler

The autograd profiler helps you understand the time and memory consumption of your operations:

python
import torch.autograd.profiler as profiler

model = YourModel()
inputs = torch.randn(32, 3, 224, 224)

with profiler.profile(record_shapes=True) as prof:
    with profiler.record_function("model_inference"):
        output = model(inputs)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

Example Output:

---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                  model_inference         0.12%     235.000us       100.00%     197.800ms     197.800ms             1
                     aten::conv2d         0.02%      30.000us        93.72%     185.438ms      23.180ms             8
                aten::convolution         0.01%      20.000us        93.70%     185.408ms      23.176ms             8
               aten::_convolution         0.02%      38.000us        93.69%     185.388ms      23.173ms             8
...
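
Recent PyTorch versions also ship the newer torch.profiler API, which supersedes the legacy autograd profiler and adds features like scheduling and TensorBoard export. A roughly equivalent sketch:

python
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        output = model(inputs)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))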

2. Anomaly Detection

PyTorch has a built-in anomaly detection feature that can help catch issues like NaN gradients early:

python
# Enable anomaly detection
torch.autograd.set_detect_anomaly(True)

def train_with_anomaly_detection(model, dataloader, criterion, optimizer, epochs=1):
    for epoch in range(epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            try:
                loss.backward()
                optimizer.step()
            except RuntimeError as e:
                print(f"Caught backward error: {e}")
                # Additional debugging information
                print(f"Loss value: {loss.item()}")

                # Check parameters for NaN values
                for name, param in model.named_parameters():
                    if torch.isnan(param).any():
                        print(f"NaN detected in parameters: {name}")

                    if param.grad is not None and torch.isnan(param.grad).any():
                        print(f"NaN detected in gradients: {name}")

                break
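
Anomaly detection adds noticeable overhead, so rather than enabling it globally you can scope it to a suspicious region with the torch.autograd.detect_anomaly context manager:

python
# Scope anomaly detection to a single forward/backward pass
with torch.autograd.detect_anomaly():
    loss = criterion(model(inputs), labels)
    loss.backward()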

Advanced Debugging Techniques

1. Visualizing Activations

Visualize intermediate activations to understand what your network is learning:

python
import torch
import torchvision
import matplotlib.pyplot as plt

class FeatureExtractor:
    def __init__(self, model, layers):
        self.model = model
        self.layers = layers
        self.features = {layer: None for layer in layers}

        # Register hooks
        for layer_name in self.layers:
            layer = dict([*self.model.named_modules()])[layer_name]
            layer.register_forward_hook(self.get_hook(layer_name))

    def get_hook(self, layer_name):
        def hook(module, input, output):
            self.features[layer_name] = output
        return hook

    def extract_features(self, x):
        _ = self.model(x)
        return self.features

# Usage example (the weights argument replaces the deprecated pretrained=True)
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
extractor = FeatureExtractor(model, ['layer1', 'layer2', 'layer3'])

# Get a sample image
input_image = torch.randn(1, 3, 224, 224)
features = extractor.extract_features(input_image)

# Visualize activations
def visualize_feature_maps(features, layer_name, num_filters=8):
    feature_map = features[layer_name][0].detach().cpu().numpy()

    fig, axes = plt.subplots(1, min(num_filters, feature_map.shape[0]), figsize=(15, 5))
    for i, ax in enumerate(axes):
        if i < feature_map.shape[0]:
            ax.imshow(feature_map[i], cmap='viridis')
            ax.set_title(f'Filter {i}')
            ax.axis('off')
    plt.suptitle(f'Feature Maps for {layer_name}')
    plt.show()

visualize_feature_maps(features, 'layer1', num_filters=8)

2. Weight and Bias Distribution Analysis

Monitor the distribution of weights and biases to catch issues like dead units or poorly scaled initializations:

python
def analyze_parameters(model, epoch=0):
    fig = plt.figure(figsize=(20, 10))
    plot_index = 1

    for name, param in model.named_parameters():
        if 'weight' in name or 'bias' in name:
            plt.subplot(3, 4, plot_index)

            # Convert to numpy and flatten
            values = param.data.cpu().numpy().flatten()

            # Plot histogram
            plt.hist(values, bins=50)
            plt.title(f"{name} - Epoch {epoch}")
            plt.xlabel("Value")
            plt.ylabel("Count")

            # Calculate and display statistics
            mean = values.mean()
            std = values.std()
            plt.axvline(mean, color='r', linestyle='dashed', linewidth=1)
            plt.text(0.05, 0.95, f"μ={mean:.4f}, σ={std:.4f}",
                     transform=plt.gca().transAxes, fontsize=10,
                     verticalalignment='top')

            plot_index += 1
            if plot_index > 12:  # Limit to 12 subplots
                break

    plt.tight_layout()
    plt.show()
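
For longer runs, static histogram grids become unwieldy. If you have TensorBoard installed, one common alternative (sketched below; the log directory name is arbitrary) is to log the same histograms with torch.utils.tensorboard and scrub through epochs interactively:

python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/debugging")  # arbitrary log directory

def log_parameter_histograms(model, epoch):
    # One histogram per parameter tensor, viewable in TensorBoard's Histograms tab
    for name, param in model.named_parameters():
        writer.add_histogram(name, param.data, global_step=epoch)
        if param.grad is not None:
            writer.add_histogram(f"{name}.grad", param.grad, global_step=epoch)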

Practical Example: Debugging a CNN Model

Let's put everything together in a real-world example. We'll create a simple CNN for image classification and debug common issues.

python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define a simple CNN with potential issues
class BuggyModel(nn.Module):
    def __init__(self):
        super(BuggyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)  # Expects a flattened 64x8x8 input
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Missing flatten operation
        x = self.fc1(x)  # This will cause a shape error
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# Set up a simple dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)

# Initialize model and optimizer
model = BuggyModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Debug step 1: Inspect data
inspect_data(trainloader)

# Debug step 2: Check model architecture (CIFAR10 images are 32x32)
try:
    inspect_model(model, input_shape=(1, 3, 32, 32))
except Exception as e:
    print(f"Model inspection failed: {str(e)}")

# Fix the model
class FixedModel(nn.Module):
    def __init__(self):
        super(FixedModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        # Calculate the correct input size:
        # CIFAR10 is 32x32, after two 2x2 pooling layers: 8x8
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Add flatten operation
        x = x.view(-1, 64 * 8 * 8)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = FixedModel()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
inspect_model(model, input_shape=(1, 3, 32, 32))

# Debug step 3: Track gradients
gradient_stats = register_gradient_hooks(model)

# Debug step 4: Train with anomaly detection
torch.autograd.set_detect_anomaly(True)

# Modified training loop with debugging
import time

def debug_training(model, trainloader, criterion, optimizer, num_batches=10):
    model.train()
    running_loss = 0.0

    for i, data in enumerate(trainloader, 0):
        if i >= num_batches:
            break

        inputs, labels = data
        optimizer.zero_grad()

        # Forward pass with timing (time.perf_counter works on CPU; for
        # accurate GPU timing you would use torch.cuda.Event instead)
        start_time = time.perf_counter()
        outputs = model(inputs)
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        print(f"Batch {i} - Forward pass time: {elapsed_ms:.2f} ms")

        # Check for NaN in outputs
        if torch.isnan(outputs).any():
            print(f"NaN detected in outputs at batch {i}")
            break

        loss = criterion(outputs, labels)

        # Check if loss is valid
        if torch.isnan(loss).any():
            print(f"NaN detected in loss at batch {i}")
            break

        print(f"Batch {i} - Loss: {loss.item():.4f}")

        # Backward pass
        loss.backward()

        # Check for NaN in gradients
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"NaN detected in gradient for {name} at batch {i}")
                return

        optimizer.step()
        running_loss += loss.item()

    print(f"Average loss: {running_loss / num_batches:.4f}")

# Run the debug training
debug_training(model, trainloader, criterion, optimizer)

# Visualize gradient statistics after training
plot_gradient_stats(gradient_stats, 'conv1.weight')

# Analyze weight distributions
analyze_parameters(model, epoch=0)

Common Issues and Solutions

| Problem | Symptoms | Possible Solutions |
| --- | --- | --- |
| Vanishing gradients | Loss plateaus; weights in early layers don't update | Skip connections, batch normalization, proper weight initialization |
| Exploding gradients | NaN values in loss or weights; model diverges | Gradient clipping, reduce learning rate, weight decay |
| Overfitting | Training loss keeps decreasing while validation loss increases | Data augmentation, dropout, regularization, early stopping |
| Underfitting | Both training and validation loss stay high | Increase model capacity, train longer, reduce regularization |
| Memory issues | CUDA out-of-memory errors | Reduce batch size, gradient accumulation, model parallelism |
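
To make one of these fixes concrete, here is a minimal sketch of gradient clipping inside a training step, reusing the model, criterion, and optimizer names from the example above (max_norm=1.0 is an illustrative threshold, not a universal default):

python
optimizer.zero_grad()
loss = criterion(model(inputs), labels)
loss.backward()
# Clip the global gradient norm before the optimizer step to guard
# against exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()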

Systematic Debugging Approach

When debugging PyTorch models, follow these steps:

  1. Start small: Use a small subset of data and a simple model (see the single-batch sanity check after this list)
  2. Verify data: Check shapes, ranges, and correctness of your data pipeline
  3. Validate model: Ensure forward/backward passes work with random data
  4. Monitor metrics: Track loss, accuracy, gradients during training
  5. Visualize: Look at activations, weight distributions, and predictions
  6. Log everything: Save checkpoints and logs for reference
  7. Isolate issues: Test components individually to find the problem
  8. Simplify: Remove parts of the model until the issue disappears
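
A quick way to combine steps 1 through 4 is the classic single-batch sanity check: a healthy model and training loop should drive the loss on one fixed batch close to zero. A minimal sketch, reusing the names from the CNN example above:

python
# Sanity check: a working model should be able to overfit a single batch
inputs, labels = next(iter(trainloader))

for step in range(200):
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(f"Step {step} - Loss: {loss.item():.4f}")
# If the loss does not approach zero, debug the model/loss/optimizer first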

Summary

Debugging PyTorch models requires a systematic approach and the right tools. In this guide, we've covered:

  • Basic debugging techniques for data and model verification
  • How to monitor and visualize gradients, activations, and weights
  • Using PyTorch's built-in debugging tools like autograd profiler and anomaly detection
  • A practical example of identifying and fixing common issues
  • A systematic approach to model debugging

By applying these techniques, you can save valuable time and ensure your models are working correctly.

Exercises

  1. Debug a model that's suffering from vanishing gradients using the techniques described above
  2. Implement a custom hook to track the mean and standard deviation of activations in each layer
  3. Create a function that automatically checks for common issues in a PyTorch model
  4. Use PyTorch's profiler to identify the bottlenecks in a complex model and optimize them
  5. Build a dashboard using TensorBoard that monitors all the key debugging metrics in real-time

