PyTorch Model Debugging
Deep learning models can be notoriously difficult to debug. When your PyTorch model isn't performing as expected, having a systematic debugging approach can save you countless hours of frustration. This guide will walk you through various techniques to effectively debug your PyTorch models.
Introduction
Model debugging in PyTorch involves identifying and fixing issues that prevent your neural networks from learning properly or producing the expected outputs. Common problems include:
- Models that don't learn (loss doesn't decrease)
- Unexpected outputs or predictions
- Runtime errors during training or inference
- Memory issues with large models
- Numerical instabilities (NaN/Infinity values)
In this guide, we'll walk through a structured approach to debugging these problems, along with the tools PyTorch provides for each step.
Basic Model Debugging Techniques
1. Verify Your Input Data
One of the first things to check is whether your input data is correct:
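import torch
import matplotlib.pyplot as plt

def inspect_data(dataloader):
    # Pull a single batch; assumes the loader yields (inputs, labels) pairs
    inputs, labels = next(iter(dataloader))

    print(f"Input shape: {inputs.shape}")
    print(f"Input dtype: {inputs.dtype}")
    print(f"Input range: [{inputs.min().item()}, {inputs.max().item()}]")
    print(f"Labels shape: {labels.shape}")
    print(f"Labels unique values: {torch.unique(labels)}")

    # Visualize the first sample if the batch looks like images (BCHW format)
    if inputs.dim() == 4:
        img = inputs[0].permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for matplotlib
        if img.shape[2] == 1:
            img = img[..., 0]  # matplotlib expects (H, W) for single-channel images
        # Rescale to [0, 1] for display; normalized inputs can fall outside that range
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        plt.imshow(img, cmap='gray' if img.ndim == 2 else None)
        plt.title(f"Label: {labels[0].item()}")
        plt.show()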
Example Output:
Input shape: torch.Size([32, 3, 224, 224])
Input dtype: torch.float32
Input range: [0.0, 1.0]
Labels shape: torch.Size([32])
Labels unique values: tensor([0, 1, 2, 3, 4])
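Printing statistics is useful, but a few data problems are easier to catch with explicit checks. The sketch below assumes a classification task whose labels are integer class indices (the form nn.CrossEntropyLoss expects in that mode); check_batch is a hypothetical helper, not a PyTorch API.

```python
import torch


def check_batch(inputs, labels, num_classes):
    """Hypothetical helper: return a list of problems found in one (inputs, labels) batch.

    Assumes labels are integer class indices for a num_classes-way classifier.
    """
    problems = []
    # NaN or Inf in the inputs propagates straight into the loss
    if not torch.isfinite(inputs).all():
        problems.append("inputs contain NaN or Inf")
    # The batch dimension of inputs and labels must agree
    if inputs.shape[0] != labels.shape[0]:
        problems.append(f"batch size mismatch: {inputs.shape[0]} inputs vs {labels.shape[0]} labels")
    # Class-index targets must be int64 and lie in [0, num_classes)
    if labels.dtype != torch.long:
        problems.append(f"labels have dtype {labels.dtype}, expected torch.int64")
    elif labels.numel() > 0 and (labels.min() < 0 or labels.max() >= num_classes):
        problems.append(
            f"labels outside [0, {num_classes - 1}]: "
            f"min={labels.min().item()}, max={labels.max().item()}"
        )
    return problems


# Example: one label (10) is out of range for a 10-class problem
inputs = torch.rand(4, 3, 8, 8)
labels = torch.tensor([0, 1, 2, 10])
print(check_batch(inputs, labels, num_classes=10))
# ['labels outside [0, 9]: min=0, max=10']
```

An out-of-range label is worth checking for explicitly because on GPU it often surfaces as an opaque device-side assertion rather than a readable error.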
2. Check Model Architecture
Ensure your model architecture is correctly defined:
import torch

def inspect_model(model, input_shape=(1, 3, 224, 224)):
    # Print model architecture
    print(model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Verify the forward pass with a random input; input_shape must match what
    # the model expects (e.g. (1, 3, 32, 32) for CIFAR-10)
    device = next(model.parameters()).device
    random_input = torch.randn(*input_shape, device=device)

    try:
        with torch.no_grad():
            output = model(random_input)
        print(f"Forward pass output shape: {output.shape}")
        print(f"Output range: [{output.min().item():.4f}, {output.max().item():.4f}]")
    except Exception as e:
        print(f"Forward pass failed with error: {e}")
        raise  # re-raise so callers can react to the failure
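When a forward pass fails with a shape error, it helps to know which layer was the last to succeed. The following sketch records each leaf module's output shape with forward hooks and keeps whatever was recorded before the failure. The helper trace_shapes and the TinyNet model (with a deliberately missing flatten) are hypothetical, written only for illustration.

```python
import torch
import torch.nn as nn


def trace_shapes(model, input_shape=(1, 3, 32, 32)):
    """Hypothetical helper: run one forward pass and record (module_name, output_shape)
    for every leaf module, in execution order. Returns (records, error_or_None)."""
    records, handles = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                records.append((name, tuple(output.shape)))
        return hook

    # Leaf modules (no children) are where the actual computation happens
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            handles.append(module.register_forward_hook(make_hook(name)))

    device = next(model.parameters()).device
    error = None
    try:
        with torch.no_grad():
            model(torch.randn(*input_shape, device=device))
    except RuntimeError as e:
        # Keep the shapes recorded so far: the last entry is the last layer that succeeded
        error = e
    finally:
        # Always remove the hooks so they don't fire on later forward passes
        for h in handles:
            h.remove()
    return records, error


class TinyNet(nn.Module):
    """Hypothetical model with a bug: no flatten before the Linear layer."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(8 * 16 * 16, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        return self.fc(x)  # bug: x is still 4-D here


records, error = trace_shapes(TinyNet(), input_shape=(1, 3, 32, 32))
for name, shape in records:
    print(f"{name:>6}: {shape}")
if error is not None:
    print(f"Forward pass failed after the last layer above: {error}")
```

Here the trace stops after pool, whose output is still 4-D, which points at the missing flatten before fc. Functional calls such as torch.relu are not modules, so they do not appear in the trace.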
3. Monitor Gradients
Tracking gradients can help identify issues like vanishing or exploding gradients:
def register_gradient_hooks(model):
    gradient_stats = {}
    
    for name, param in model.named_parameters():
        if param.requires_grad:
            def hook_factory(name):
                def hook(grad):
                    if name not in gradient_stats:
                        gradient_stats[name] = {"min": [], "max": [], "mean": [], "norm": []}
                    
                    gradient_stats[name]["min"].append(grad.min().item())
                    gradient_stats[name]["max"].append(grad.max().item())
                    gradient_stats[name]["mean"].append(grad.mean().item())
                    gradient_stats[name]["norm"].append(grad.norm().item())
                return hook
            
            param.register_hook(hook_factory(name))
    
    return gradient_stats
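# Usage during training (YourModel is a placeholder for your own nn.Module)
model = YourModel()
gradient_stats = register_gradient_hooks(model)
# After some training iterations, plot the recorded statistics for one parameter
import matplotlib.pyplot as plt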
def plot_gradient_stats(gradient_stats, layer_name):
    stats = gradient_stats[layer_name]
    plt.figure(figsize=(12, 8))
    
    plt.subplot(2, 2, 1)
    plt.plot(stats["min"])
    plt.title(f"{layer_name} - Gradient Min")
    
    plt.subplot(2, 2, 2)
    plt.plot(stats["max"])
    plt.title(f"{layer_name} - Gradient Max")
    
    plt.subplot(2, 2, 3)
    plt.plot(stats["mean"])
    plt.title(f"{layer_name} - Gradient Mean")
    
    plt.subplot(2, 2, 4)
    plt.plot(stats["norm"])
    plt.title(f"{layer_name} - Gradient Norm")
    
    plt.tight_layout()
    plt.show()
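Hooks record a history over many steps. For a quick one-off check right after loss.backward(), you can also read param.grad directly. In this sketch, report_grad_norms is a hypothetical helper, and its low/high thresholds are arbitrary illustrative values, not PyTorch defaults; tune them for your model.

```python
import math

import torch
import torch.nn as nn


def report_grad_norms(model, low=1e-7, high=1e3):
    """Hypothetical helper: call after loss.backward() and before optimizer.step().

    Returns ({param_name: grad L2 norm}, [(param_name, reason), ...]).
    The low/high thresholds are illustrative assumptions.
    """
    norms, flagged = {}, []
    for name, param in model.named_parameters():
        if param.grad is None:
            # Parameter was not used in the forward pass, or it is frozen
            flagged.append((name, "no gradient"))
            continue
        norm = param.grad.detach().norm().item()
        norms[name] = norm
        if not math.isfinite(norm):
            flagged.append((name, "non-finite"))
        elif norm < low:
            flagged.append((name, "possibly vanishing"))
        elif norm > high:
            flagged.append((name, "possibly exploding"))
    return norms, flagged


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
loss = nn.CrossEntropyLoss()(model(torch.randn(8, 10)), torch.randint(0, 2, (8,)))
loss.backward()

norms, flagged = report_grad_norms(model)
for name, norm in norms.items():
    print(f"{name}: {norm:.3e}")
print("Flagged:", flagged)
```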
Using PyTorch's Built-in Debugging Tools
1. Autograd Profiler
The autograd profiler breaks down the time spent in each operation (pass profile_memory=True to record memory allocations as well):
import torch
import torch.autograd.profiler as profiler
model = YourModel()  # placeholder for your own nn.Module
inputs = torch.randn(32, 3, 224, 224)
with profiler.profile(record_shapes=True) as prof:
    with profiler.record_function("model_inference"):
        output = model(inputs)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
Example Output:
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                 model_inference        0.12%     235.000us       100.00%     197.800ms     197.800ms             1
                      aten::conv2d        0.02%      30.000us        93.72%     185.438ms      23.180ms             8
                  aten::convolution        0.01%      20.000us        93.70%     185.408ms      23.176ms             8
                 aten::_convolution        0.02%      38.000us        93.69%     185.388ms      23.173ms             8
...
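Newer PyTorch releases also provide the torch.profiler module, which can record CPU and (when available) CUDA activity in the same run. A minimal sketch, assuming a small stand-in model that you would replace with your own, and with profile_memory=True so the table includes memory columns:

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function

# Small stand-in model and batch; substitute your own
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10),
)
inputs = torch.randn(8, 3, 32, 32)

# Profile CUDA kernels too if a GPU is present
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)
    model, inputs = model.cuda(), inputs.cuda()

with profile(activities=activities, record_shapes=True, profile_memory=True) as prof:
    with record_function("model_inference"):  # labels this region in the results
        with torch.no_grad():
            output = model(inputs)

# Sort by memory allocated directly by each op (excluding its children)
table = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10)
print(table)
```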
2. Anomaly Detection
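PyTorch has a built-in anomaly detection mode that helps catch problems such as NaN gradients early. When it is enabled, any backward computation that produces NaN raises a RuntimeError, accompanied by the traceback of the forward operation that created the failing node. The extra checks slow training down noticeably, so enable the mode only while debugging: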
# Enable anomaly detection
torch.autograd.set_detect_anomaly(True)
def train_with_anomaly_detection(model, dataloader, criterion, optimizer, epochs=1):
    model.train()
    for epoch in range(epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            try:
                # With anomaly detection on, backward() raises if any op produces NaN
                loss.backward()
                optimizer.step()
            except RuntimeError as e:
                print(f"Caught backward error in epoch {epoch}: {e}")
                # Additional debugging information
                print(f"Loss value: {loss.item()}")

                # Check parameters and their gradients for NaN values
                for name, param in model.named_parameters():
                    if torch.isnan(param).any():
                        print(f"NaN detected in parameters: {name}")

                    if param.grad is not None and torch.isnan(param.grad).any():
                        print(f"NaN detected in gradients: {name}")

                # Stop training entirely (a plain break would only leave the current epoch)
                return
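To see what anomaly mode reports without enabling it globally, you can use the torch.autograd.detect_anomaly() context manager, which limits the overhead to the enclosed code. This toy example deliberately takes the square root of a negative number so that the backward pass produces NaN:

```python
import torch

# sqrt(-1) is NaN in the forward pass, and its derivative 1 / (2 * sqrt(x))
# is NaN in the backward pass, which anomaly mode turns into an error.
x = torch.tensor([4.0, -1.0], requires_grad=True)

caught = None
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x).sum()
    try:
        y.backward()
    except RuntimeError as e:
        caught = e

print(f"Anomaly mode raised: {caught}")
```

The error message names the backward function that produced the NaN (here the one for torch.sqrt), and because the forward pass also ran in anomaly mode, PyTorch emits a warning containing the traceback of the forward call responsible.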
Advanced Debugging Techniques
1. Visualizing Activations
Visualize intermediate activations to understand what your network is learning:
import torch
import torchvision
import matplotlib.pyplot as plt

class FeatureExtractor:
    def __init__(self, model, layers):
        self.model = model
        self.layers = layers
        self.features = {layer: None for layer in layers}

        # Register a forward hook on each requested submodule
        modules = dict(self.model.named_modules())
        for layer_name in self.layers:
            modules[layer_name].register_forward_hook(self.get_hook(layer_name))

    def get_hook(self, layer_name):
        def hook(module, input, output):
            # Detach so stored activations don't keep the autograd graph alive
            self.features[layer_name] = output.detach()
        return hook

    def extract_features(self, x):
        self.model.eval()
        with torch.no_grad():
            self.model(x)
        return self.features

# Usage example (the weights= API needs torchvision >= 0.13; older versions use pretrained=True)
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
extractor = FeatureExtractor(model, ['layer1', 'layer2', 'layer3'])
# Random stand-in for a real, preprocessed image
input_image = torch.randn(1, 3, 224, 224)
features = extractor.extract_features(input_image)
# Visualize activations
def visualize_feature_maps(features, layer_name, num_filters=8):
    feature_map = features[layer_name][0].cpu().numpy()  # first sample: (C, H, W)
    n = min(num_filters, feature_map.shape[0])

    # squeeze=False keeps axes 2-D, so indexing works even when n == 1
    fig, axes = plt.subplots(1, n, figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(feature_map[i], cmap='viridis')
        ax.set_title(f'Filter {i}')
        ax.axis('off')
    plt.suptitle(f'Feature Maps for {layer_name}')
    plt.show()
visualize_feature_maps(features, 'layer1', num_filters=8)
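The same forward-hook mechanism can localize numerical problems in the forward pass, complementing anomaly detection, which checks the backward pass. The sketch below uses a hypothetical helper, find_first_nonfinite, and a deliberately broken hypothetical LogLayer to return the first module whose output contains NaN or Inf:

```python
import torch
import torch.nn as nn


def find_first_nonfinite(model, x):
    """Hypothetical helper: return the name of the first module whose output
    contains NaN or Inf, or None if every module output is finite."""
    offenders, handles = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                offenders.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()
    return offenders[0] if offenders else None


class LogLayer(nn.Module):
    """Hypothetical layer: log(0) = -inf whenever the preceding ReLU outputs zero."""

    def forward(self, x):
        return torch.log(x)


broken_model = nn.Sequential(nn.ReLU(), LogLayer(), nn.Linear(4, 2))
x = torch.tensor([[1.0, -2.0, 3.0, 0.5]])
print(find_first_nonfinite(broken_model, x))  # "1", the LogLayer
```

A forward hook fires when its module finishes, so submodules report before their parent containers; the first entry is therefore the innermost module where the problem appeared.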
2. Weight and Bias Distribution Analysis
Monitor the distribution of weights and biases to catch problems such as weights that grow without bound, collapse toward zero, or never move away from their initialization:
def analyze_parameters(model, epoch=0):
    fig = plt.figure(figsize=(20, 10))
    plot_index = 1
    
    for name, param in model.named_parameters():
        if 'weight' in name or 'bias' in name:
            plt.subplot(3, 4, plot_index)
            
            # Detach from autograd, move to CPU, and flatten to 1-D for the histogram
            values = param.detach().cpu().numpy().flatten()
            
            # Plot histogram
            plt.hist(values, bins=50)
            plt.title(f"{name} - Epoch {epoch}")
            plt.xlabel("Value")
            plt.ylabel("Count")
            
            # Calculate and display statistics
            mean = values.mean()
            std = values.std()
            plt.axvline(mean, color='r', linestyle='dashed', linewidth=1)
            plt.text(0.05, 0.95, f"μ={mean:.4f}, σ={std:.4f}", 
                     transform=plt.gca().transAxes, fontsize=10,
                     verticalalignment='top')
            
            plot_index += 1
            if plot_index > 12:  # Limit to 12 subplots
                break
    
    plt.tight_layout()
    plt.show()
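Histograms show the shape of each distribution, but they make it hard to spot a parameter that simply stops changing, one of the symptoms of vanishing gradients. A simple numeric check compares parameters before and after an optimizer step. The helpers snapshot and relative_change are hypothetical names, and the frozen layer is simulated on purpose:

```python
import torch
import torch.nn as nn


def snapshot(model):
    """Hypothetical helper: copy every parameter so later values can be compared."""
    return {name: p.detach().clone() for name, p in model.named_parameters()}


def relative_change(model, before):
    """Hypothetical helper: {name: ||p_now - p_before|| / (||p_before|| + eps)}."""
    eps = 1e-12
    changes = {}
    for name, p in model.named_parameters():
        delta = (p.detach() - before[name]).norm().item()
        changes[name] = delta / (before[name].norm().item() + eps)
    return changes


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
model[0].weight.requires_grad_(False)  # simulate an accidentally frozen layer
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)

before = snapshot(model)
loss = nn.CrossEntropyLoss()(model(torch.randn(16, 10)), torch.randint(0, 2, (16,)))
loss.backward()
optimizer.step()

for name, change in relative_change(model, before).items():
    flag = "  <- not updating" if change == 0 else ""
    print(f"{name}: {change:.2e}{flag}")
```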
Practical Example: Debugging a CNN Model
Let's put everything together in a worked example: a simple CNN for CIFAR-10 image classification that contains a deliberate bug, which we'll track down with the techniques above.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Define a simple CNN with potential issues
class BuggyModel(nn.Module):
    def __init__(self):
        super(BuggyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)  # Correct size for 32x32 inputs (8x8 after two poolings)
        self.fc2 = nn.Linear(512, 10)
        
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Missing flatten operation
        x = self.fc1(x)  # This will cause a shape error
        x = torch.relu(x)
        x = self.fc2(x)
        return x
# Set up a simple dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                       download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                         shuffle=True, num_workers=2)
# Initialize model and optimizer
model = BuggyModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Debug step 1: Inspect data
inspect_data(trainloader)
# Debug step 2: Check model architecture
# Run a forward pass on a CIFAR-10-sized batch (3x32x32). Using a batch of 2
# also exposes reshapes that silently change the batch dimension.
sample_batch = torch.randn(2, 3, 32, 32)
try:
    model(sample_batch)
except RuntimeError as e:
    print(f"Model check failed: {e}")

    # Fix the model
    class FixedModel(nn.Module):
        def __init__(self):
            super(FixedModel, self).__init__()
            self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
            # CIFAR-10 images are 32x32; two 2x2 pooling layers reduce them to 8x8
            self.fc1 = nn.Linear(64 * 8 * 8, 512)
            self.fc2 = nn.Linear(512, 10)

        def forward(self, x):
            x = self.pool(torch.relu(self.conv1(x)))
            x = self.pool(torch.relu(self.conv2(x)))
            # Add the missing flatten; flatten(x, 1) keeps dim 0 as the batch size,
            # whereas view(-1, ...) can silently fold a wrong shape into the batch
            x = torch.flatten(x, 1)
            x = torch.relu(self.fc1(x))
            x = self.fc2(x)
            return x

    model = FixedModel()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    print(f"Fixed model output shape: {model(sample_batch).shape}")  # torch.Size([2, 10])
# Debug step 3: Track gradients
gradient_stats = register_gradient_hooks(model)
# Debug step 4: Train with anomaly detection
torch.autograd.set_detect_anomaly(True)
# Modified training loop with debugging
import time

def debug_training(model, trainloader, criterion, optimizer, num_batches=10):
    model.train()
    running_loss = 0.0
    batches_done = 0

    for i, (inputs, labels) in enumerate(trainloader):
        if i >= num_batches:
            break

        optimizer.zero_grad()

        # Time the forward pass with a wall clock so this works on CPU as well as GPU;
        # on GPU, synchronize before reading the clock so queued kernels are counted
        start_time = time.perf_counter()
        outputs = model(inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start_time) * 1000

        print(f"Batch {i} - Forward pass time: {elapsed_ms:.2f} ms")

        # Check for NaN/Inf in outputs
        if not torch.isfinite(outputs).all():
            print(f"NaN/Inf detected in outputs at batch {i}")
            break

        loss = criterion(outputs, labels)

        # Check if loss is valid
        if not torch.isfinite(loss):
            print(f"NaN/Inf detected in loss at batch {i}")
            break

        print(f"Batch {i} - Loss: {loss.item():.4f}")

        # Backward pass
        loss.backward()

        # Check for NaN in gradients
        for name, param in model.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"NaN detected in gradient for {name} at batch {i}")
                return

        optimizer.step()
        running_loss += loss.item()
        batches_done += 1

    # Average over the batches actually completed, not the requested num_batches
    if batches_done > 0:
        print(f"Average loss: {running_loss / batches_done:.4f}")
# Run the debug training
debug_training(model, trainloader, criterion, optimizer)
# Visualize gradient statistics after training
plot_gradient_stats(gradient_stats, 'conv1.weight')
# Analyze weight distributions
analyze_parameters(model, epoch=0)
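The fix above hardcodes 64 * 8 * 8, which has to be recomputed whenever the input size or the convolutional stack changes. One alternative, sketched here on the assumption that your PyTorch version includes nn.LazyLinear (a lazy module that infers in_features from the first input it receives), is to let the layer size itself. LazyCNN is a hypothetical variant of FixedModel:

```python
import torch
import torch.nn as nn


class LazyCNN(nn.Module):
    """Hypothetical variant of FixedModel: fc1 infers its input size on the first forward pass."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.LazyLinear(512)  # in_features is set from the first input
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep dim 0 as the batch dimension
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


model = LazyCNN()
# Materialize the lazy layer with a dummy batch BEFORE creating the optimizer,
# so the optimizer receives fully initialized parameters.
with torch.no_grad():
    model(torch.randn(2, 3, 32, 32))
print(model.fc1)  # Linear(in_features=4096, out_features=512, bias=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```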
Common Issues and Solutions
| Problem | Symptoms | Possible Solutions | 
|---|---|---|
| Vanishing Gradients | Loss plateaus, weights in early layers don't update | Use skip connections, batch normalization, proper weight initialization | 
| Exploding Gradients | NaN values in loss or weights, model diverges | Gradient clipping, reduce learning rate, weight decay | 
| Overfitting | Training loss keeps decreasing, validation loss increases | Data augmentation, dropout, regularization, early stopping | 
| Underfitting | Both training and validation loss stay high | Increase model capacity, train longer, reduce regularization | 
| Memory Issues | CUDA out of memory errors | Reduce batch size, use gradient accumulation, model parallelism | 
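As a concrete instance of the gradient clipping remedy for exploding gradients, the sketch below wraps torch.nn.utils.clip_grad_norm_ into a single training step. The function train_step and the max_norm value of 1.0 are illustrative choices, not defaults to copy blindly:

```python
import torch
import torch.nn as nn


def train_step(model, inputs, labels, criterion, optimizer, max_norm=1.0):
    """Hypothetical helper: one training step with gradient-norm clipping.
    max_norm=1.0 is an illustrative value, not a universal recommendation."""
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    # Rescales all gradients in place so their combined L2 norm is at most max_norm,
    # and returns the total norm measured before clipping (worth logging)
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss.item(), total_norm.item()


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
# Deliberately large inputs produce large gradients
inputs, labels = torch.randn(16, 10) * 100, torch.randint(0, 3, (16,))

for step in range(3):
    loss, norm = train_step(model, inputs, labels, criterion, optimizer)
    print(f"step {step}: loss={loss:.3f}, grad norm before clipping={norm:.2f}")
```

If the logged pre-clipping norm exceeds max_norm on most steps, clipping is doing a lot of work, which may indicate that the learning rate is too high.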
Systematic Debugging Approach
When debugging PyTorch models, follow these steps:
- Start small: Use a small subset of data and a simple model
- Verify data: Check shapes, ranges, and correctness of your data pipeline
- Validate model: Ensure forward/backward passes work with random data
- Monitor metrics: Track loss, accuracy, gradients during training
- Visualize: Look at activations, weight distributions, and predictions
- Log everything: Save checkpoints and logs for reference
- Isolate issues: Test components individually to find the problem
- Simplify: Remove parts of the model until the issue disappears
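The first step, starting small, is often put into practice by overfitting a single batch: if the model cannot drive the loss close to zero on a handful of examples, there is likely a bug in the model, the loss, or the training loop. A minimal sketch, assuming a classification setup; overfit_one_batch is a hypothetical helper, and its step count and learning rate are illustrative:

```python
import torch
import torch.nn as nn


def overfit_one_batch(model, inputs, labels, steps=300, lr=1e-2):
    """Hypothetical helper: train repeatedly on one batch and return (first_loss, last_loss).

    A working model/loss/optimizer combination should memorize a tiny batch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    first_loss = None
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        if first_loss is None:
            first_loss = loss.item()
    return first_loss, loss.item()


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
inputs, labels = torch.randn(8, 16), torch.randint(0, 4, (8,))
first, last = overfit_one_batch(model, inputs, labels)
print(f"loss: {first:.3f} -> {last:.4f}")
```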
Summary
Debugging PyTorch models requires a systematic approach and the right tools. In this guide, we've covered:
- Basic debugging techniques for data and model verification
- How to monitor and visualize gradients, activations, and weights
- Using PyTorch's built-in debugging tools like autograd profiler and anomaly detection
- A practical example of identifying and fixing common issues
- A systematic approach to model debugging
By applying these techniques, you can save valuable time and ensure your models are working correctly.
Additional Resources
- PyTorch Documentation on Autograd
- PyTorch Forums for specific issues
- PyTorch Profiler for performance debugging
- TensorBoard integration with PyTorch for visualization
Exercises
- Debug a model that's suffering from vanishing gradients using the techniques described above
- Implement a custom hook to track the mean and standard deviation of activations in each layer
- Create a function that automatically checks for common issues in a PyTorch model
- Use PyTorch's profiler to identify the bottlenecks in a complex model and optimize them
- Build a dashboard using TensorBoard that monitors all the key debugging metrics in real-time