
PyTorch Visualization Tools

Introduction

Visualizing your models, data, and training processes is a crucial aspect of debugging and understanding deep learning systems. Without proper visualization, deep learning can often feel like a mysterious black box. PyTorch and its ecosystem offer several visualization tools that help you gain insights into your models and training progress.

In this tutorial, we'll explore various visualization techniques in PyTorch that will help you:

  • Visualize model architectures
  • Track training metrics
  • Inspect tensors and data
  • Visualize feature maps and activations
  • Debug gradient flow and parameter updates

Let's dive in and learn how to make your PyTorch projects more transparent and easier to debug!

TensorBoard Integration with PyTorch

TensorBoard is one of the most powerful visualization tools available for deep learning, and PyTorch integrates with it seamlessly through the torch.utils.tensorboard module.

Setting Up TensorBoard

First, let's install the required packages if you haven't already:

```bash
pip install tensorboard
```

Now, let's see how to use TensorBoard with PyTorch:

```python
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np

# Create a SummaryWriter instance
writer = SummaryWriter('runs/my_experiment')

# Define a simple neural network for demonstration
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize the model
model = SimpleNN()

# Get some example data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Add model graph to tensorboard
writer.add_graph(model, images)

# Close the writer
writer.close()
```

After running this code, you can launch TensorBoard with:

```bash
tensorboard --logdir=runs
```

Then, open your browser and go to http://localhost:6006 to see your model graph.
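If you want to confirm what was logged without opening the browser (for example on a remote machine or in an automated check), the event files can also be read back in Python. This is a minimal sketch, assuming the `tensorboard` package's `EventAccumulator` class from `tensorboard.backend.event_processing.event_accumulator`; that module belongs to TensorBoard's internals, so its location may change between versions. The helper name `log_and_read_back`, the tag, and the values are illustrative only.

```python
import tempfile

from torch.utils.tensorboard import SummaryWriter
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


def log_and_read_back(values, tag='training loss'):
    """Log `values` as scalars in a temporary run directory, then read them back."""
    with tempfile.TemporaryDirectory() as logdir:
        writer = SummaryWriter(logdir)
        for step, value in enumerate(values):
            writer.add_scalar(tag, value, step)
        writer.close()  # Flush the event file to disk before reading it

        accumulator = EventAccumulator(logdir)
        accumulator.Reload()  # Parse the event files in logdir
        events = accumulator.Scalars(tag)
        return [e.step for e in events], [e.value for e in events]


steps, values = log_and_read_back([1.0, 0.5, 0.25])
print(steps, values)
```

Scalars are stored as 32-bit floats, so values that are not exactly representable (such as 0.1) come back slightly rounded.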

Tracking Training Metrics

TensorBoard really shines when tracking metrics over time during training. Here's how to log loss and accuracy during training:

```python
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training loop
for epoch in range(2):  # Just 2 epochs for demonstration
    running_loss = 0.0
    running_correct = 0
    running_total = 0

    for i, (inputs, labels) in enumerate(trainloader):
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate statistics over the current logging window
        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        running_correct += (predicted == labels).sum().item()
        running_total += labels.size(0)

        if i % 100 == 99:  # Every 100 mini-batches
            global_step = epoch * len(trainloader) + i
            # Log the average loss and accuracy over the last 100 mini-batches
            writer.add_scalar('training loss', running_loss / 100, global_step)
            writer.add_scalar('training accuracy', running_correct / running_total, global_step)

            running_loss = 0.0
            running_correct = 0
            running_total = 0

writer.close()
```
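The x-axis value passed to `add_scalar` is `epoch * len(trainloader) + i`, which keeps steps increasing across epoch boundaries instead of restarting at zero each epoch. A small pure-Python sketch of that arithmetic, assuming the CIFAR-10 training split of 50,000 images and `batch_size=4` from above (12,500 batches per epoch); `global_step` is a hypothetical helper name:

```python
def global_step(epoch, batch_idx, batches_per_epoch):
    """Map (epoch, batch index) to one monotonically increasing x-axis value."""
    return epoch * batches_per_epoch + batch_idx


# CIFAR-10 has 50,000 training images; batch_size=4 gives 12,500 batches per epoch
batches_per_epoch = 50_000 // 4

# Steps at which the loop above logs (i % 100 == 99), over two epochs
logged_steps = [
    global_step(epoch, i, batches_per_epoch)
    for epoch in range(2)
    for i in range(batches_per_epoch)
    if i % 100 == 99
]

print(logged_steps[:3])       # [99, 199, 299]
print(logged_steps[124:127])  # [12499, 12599, 12699]: same spacing across the epoch boundary
```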

Visualizing Images

TensorBoard also allows you to visualize images, which can be helpful for debugging data augmentation or model outputs:

```python
# Get a batch of images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Unnormalize from [-1, 1] back to [0, 1] (the range add_image expects for
# float tensors), then create a grid of images and add it to TensorBoard
img_grid = torchvision.utils.make_grid(images / 2 + 0.5)
writer.add_image('Four CIFAR Images', img_grid)

# You can also log each image under its own tag
for i in range(4):
    writer.add_image(f'CIFAR Image #{i}', images[i] / 2 + 0.5, 0)

writer.close()
```
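`add_image` assumes a channels-first `(C, H, W)` tensor by default, and `add_images` assumes a batch in `(N, C, H, W)` layout; channels-last arrays (the layout PIL and OpenCV produce) need `dataformats='HWC'`. The sketch below logs random data (not CIFAR images) to a temporary directory under hypothetical tag names and reads the tags back, again assuming TensorBoard's internal `EventAccumulator` class is available:

```python
import tempfile

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

with tempfile.TemporaryDirectory() as logdir:
    demo_writer = SummaryWriter(logdir)

    # A batch of 4 random RGB images, (N, C, H, W) layout, float values in [0, 1]
    batch = torch.rand(4, 3, 32, 32)
    demo_writer.add_images('random batch', batch, 0)  # dataformats='NCHW' is the default

    # One channels-last (H, W, C) image
    hwc_image = np.random.rand(32, 32, 3).astype(np.float32)
    demo_writer.add_image('channels-last image', hwc_image, 0, dataformats='HWC')
    demo_writer.close()

    # Read the event file back and list the image tags it contains
    accumulator = EventAccumulator(logdir)
    accumulator.Reload()
    image_tags = sorted(accumulator.Tags()['images'])

print(image_tags)
```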

Visualizing Histograms and Distributions

TensorBoard can show histograms of model parameters and gradients:

```python
for epoch in range(2):
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()  # Compute gradients; without this, param.grad is None or stale
        optimizer.step()

        # Log histograms of model parameters and their gradients
        step = epoch * len(trainloader) + i
        for name, param in model.named_parameters():
            writer.add_histogram(f'Parameters/{name}', param, step)
            writer.add_histogram(f'Gradients/{name}', param.grad, step)

        if i == 100:  # Only log the first ~100 batches per epoch to avoid overwhelming TensorBoard
            break

writer.close()
```
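Histograms can get heavy when logged every step. A lighter option, which is our suggestion rather than part of the workflow above, is to log one scalar per parameter: the L2 norm of its gradient. This sketch uses a tiny stand-in `nn.Sequential` model and random data (both assumptions) and cross-checks the per-tensor norms against `torch.nn.utils.clip_grad_norm_`, which returns the total norm over all gradients; passing `max_norm=float('inf')` means no actual clipping happens. `grad_norms` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def grad_norms(module):
    """Return {parameter name: L2 norm of its gradient} for parameters that have one."""
    return {name: p.grad.norm().item()
            for name, p in module.named_parameters() if p.grad is not None}


# Tiny stand-in model and random batch, for illustration only
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
toy_inputs, toy_targets = torch.randn(5, 8), torch.randint(0, 3, (5,))

toy_loss = nn.CrossEntropyLoss()(toy_model(toy_inputs), toy_targets)
toy_loss.backward()

norms = grad_norms(toy_model)
for name, value in norms.items():
    # Inside a training loop: writer.add_scalar(f'GradNorm/{name}', value, step)
    print(f'{name}: {value:.4f}')

# The global norm is the L2 norm of the per-tensor norms; clip_grad_norm_ returns it.
# max_norm=inf leaves the gradients unchanged.
total_from_parts = sum(v ** 2 for v in norms.values()) ** 0.5
total_from_torch = torch.nn.utils.clip_grad_norm_(toy_model.parameters(), max_norm=float('inf')).item()
print(f'total gradient norm: {total_from_parts:.4f} vs {total_from_torch:.4f}')
```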

Matplotlib Integration

While TensorBoard is great for ongoing training monitoring, sometimes you need to create custom plots for presentations or papers. PyTorch works seamlessly with Matplotlib:

```python
import matplotlib.pyplot as plt

# Get a single image and its label
image, label = trainset[0]
image = image / 2 + 0.5  # Unnormalize the image

# Convert tensor to numpy for plotting
npimg = image.numpy()

# Plot the image
plt.figure(figsize=(10, 8))
plt.imshow(np.transpose(npimg, (1, 2, 0)))  # Convert from (C, H, W) to (H, W, C)
plt.title(f'Label: {trainset.classes[label]}')
plt.show()
```
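The expression `image / 2 + 0.5` undoes the `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transform: per channel it computes y = (x - 0.5) / 0.5 = 2x - 1, so x = y / 2 + 0.5. For data normalized with other statistics, the general inverse is x = y * std + mean. A minimal sketch; `unnormalize` is a hypothetical helper, and the normalization is simulated by hand on a random tensor instead of loading CIFAR-10:

```python
import torch


def unnormalize(img, mean, std):
    """Invert transforms.Normalize on a (C, H, W) tensor: x = y * std + mean per channel."""
    mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    return img * std + mean


# Simulate Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) on a random image in [0, 1]
torch.manual_seed(0)
original = torch.rand(3, 32, 32)
mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
normalized = (original - torch.tensor(mean).view(-1, 1, 1)) / torch.tensor(std).view(-1, 1, 1)

print(torch.allclose(unnormalize(normalized, mean, std), original, atol=1e-6))  # True
print(torch.allclose(normalized / 2 + 0.5, original, atol=1e-6))  # True: the shortcut used above
```

For example, `unnormalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))` would undo the ImageNet normalization commonly used with torchvision's pretrained models.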

Visualizing Feature Maps

Understanding what your network "sees" can be invaluable for debugging. Let's visualize the feature maps of a convolutional layer:

```python
def visualize_feature_maps(model, input_tensor):
    # Register a hook to capture the output of conv1
    feature_maps = []

    def hook_function(module, input, output):
        feature_maps.append(output)

    # Register the hook
    hook = model.conv1.register_forward_hook(hook_function)

    # Forward pass
    with torch.no_grad():
        output = model(input_tensor)

    # Remove the hook
    hook.remove()

    # Get the feature maps
    feature_map = feature_maps[0]  # First layer feature maps

    # Plot the feature maps
    fig, axs = plt.subplots(1, feature_map.size(1), figsize=(15, 5))
    for i in range(feature_map.size(1)):  # Loop over each channel
        # Extract the i-th feature map
        fm = feature_map[0, i].detach().cpu().numpy()
        axs[i].imshow(fm, cmap='viridis')
        axs[i].axis('off')
    plt.tight_layout()
    plt.show()

# Get a single image and add batch dimension
image, _ = trainset[0]
input_tensor = image.unsqueeze(0)  # Add batch dimension

# Visualize feature maps
visualize_feature_maps(model, input_tensor)
```
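The same forward-hook mechanism can record the output shape of every layer, which shows where `16 * 5 * 5` in `SimpleNN` comes from: a 32x32 input shrinks to 28x28 after the first 5x5 convolution, 14x14 after pooling, 10x10 after the second convolution, and 5x5 after the second pooling. This sketch assumes an `nn.Sequential` re-creation of `SimpleNN` (same layers, with `nn.Flatten` in place of `x.view`) so that it runs on its own; `trace_shapes` is a hypothetical helper.

```python
import torch
import torch.nn as nn

# nn.Sequential re-creation of SimpleNN (same layers; nn.Flatten replaces x.view)
net = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
    nn.Linear(120, 84), nn.ReLU(),
    nn.Linear(84, 10),
)


def trace_shapes(model, input_tensor):
    """Return [(module name, output shape), ...] for every leaf module, in call order."""
    records = []
    handles = []
    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # Leaf modules only
            # `name=name` binds the current name so each hook keeps its own label
            def hook(mod, inputs, output, name=name):
                records.append((name, tuple(output.shape)))
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(input_tensor)
    finally:
        for handle in handles:  # Always remove hooks, even if the forward pass fails
            handle.remove()
    return records


for name, shape in trace_shapes(net, torch.zeros(1, 3, 32, 32)):
    print(name, type(net[int(name)]).__name__, shape)
```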

Loss Curve Visualization

Tracking loss over time is one of the most common visualizations in deep learning:

```python
def train_and_plot(model, trainloader, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Lists to store metrics
    train_losses = []
    train_accuracies = []

    for epoch in range(epochs):
        running_loss = 0.0
        running_acc = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            running_acc += (predicted == labels).sum().item() / labels.size(0)

            # Record metrics every 100 batches
            if i % 100 == 99:
                avg_loss = running_loss / 100
                avg_acc = running_acc / 100
                train_losses.append(avg_loss)
                train_accuracies.append(avg_acc)
                running_loss = 0.0
                running_acc = 0.0
                print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {avg_loss:.3f}, accuracy: {avg_acc:.3f}')

    # Plot loss curve
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses)
    plt.title('Training Loss')
    plt.xlabel('Steps (hundreds)')
    plt.ylabel('Loss')

    # Plot accuracy curve
    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies)
    plt.title('Training Accuracy')
    plt.xlabel('Steps (hundreds)')
    plt.ylabel('Accuracy')

    plt.tight_layout()
    plt.show()

    return model

# Train the model and visualize the results
model = train_and_plot(SimpleNN(), trainloader, epochs=2)
```
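Mini-batch loss is noisy, so curves are often smoothed before plotting. TensorBoard's scalar dashboard has a smoothing slider for this; with Matplotlib you can smooth the recorded values yourself. A minimal sketch using a plain exponential moving average (our choice of technique; `ema_smooth` and its default weight of 0.6 are illustrative):

```python
def ema_smooth(values, weight=0.6):
    """Exponential moving average: s_0 = x_0, s_t = weight * s_{t-1} + (1 - weight) * x_t."""
    smoothed = []
    last = None
    for x in values:
        last = x if last is None else weight * last + (1 - weight) * x
        smoothed.append(last)
    return smoothed


print(ema_smooth([0.0, 10.0, 10.0], weight=0.5))  # [0.0, 5.0, 7.5]
```

Inside `train_and_plot`, a line such as `plt.plot(ema_smooth(train_losses))` could be drawn next to the raw curve; higher weights give smoother but more delayed curves.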

PyTorch Built-in Visualization Tools

PyTorch's companion library torchvision provides utilities for visualizing batches of image tensors, and third-party packages such as torchinfo can summarize model architectures.

Visualizing Tensors with torchvision.utils

```python
def visualize_batch(dataloader):
    dataiter = iter(dataloader)
    images, labels = next(dataiter)

    # Create a grid from the images
    img_grid = torchvision.utils.make_grid(images, nrow=4, padding=10)
    img_grid = img_grid / 2 + 0.5  # Unnormalize

    # Convert to numpy and transpose
    npimg = img_grid.numpy()
    plt.figure(figsize=(15, 15))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    plt.show()

    # Print labels
    print('Labels:', [trainset.classes[label] for label in labels])

# Visualize a batch of images
visualize_batch(trainloader)
```
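It can help to know how large the `make_grid` output will be before plotting it. Based on our reading of torchvision's `make_grid` source (worth confirming against your installed version with `torchvision.utils.make_grid(images, nrow=4, padding=10).shape`), each image occupies a cell of `(H + padding) x (W + padding)`, images are laid out `nrow` per row, and one extra `padding` closes the bottom and right edges. A pure-Python sketch of that arithmetic for batches of two or more images; `grid_shape` is a hypothetical helper, and it does not model `make_grid` expanding single-channel images to three channels:

```python
import math


def grid_shape(n_images, channels, height, width, nrow=8, padding=2):
    """(C, H, W) of a make_grid-style layout for a batch of two or more images."""
    per_row = min(nrow, n_images)          # Images per row
    rows = math.ceil(n_images / per_row)   # Number of rows
    cell_h, cell_w = height + padding, width + padding
    # Each cell has padding on its top/left; one extra padding closes the bottom/right
    return (channels, cell_h * rows + padding, cell_w * per_row + padding)


print(grid_shape(4, 3, 32, 32, nrow=4, padding=10))  # (3, 52, 178), as in visualize_batch
print(grid_shape(4, 3, 32, 32))                      # (3, 36, 138) with the defaults
```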

Visualizing Model Summary

For a quick overview of your model, you can use the third-party library torchinfo (formerly torch-summary):

```bash
pip install torchinfo
```

```python
from torchinfo import summary

# Display model summary
summary(model, input_size=(4, 3, 32, 32))
```

Output:

```text
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleNN                                 [4, 10]                   --
├─Conv2d: 1-1                            [4, 6, 28, 28]            456
├─MaxPool2d: 1-2                         [4, 6, 14, 14]            --
├─Conv2d: 1-3                            [4, 16, 10, 10]           2,416
├─MaxPool2d: 1-4                         [4, 16, 5, 5]             --
├─Linear: 1-5                            [4, 120]                  48,120
├─Linear: 1-6                            [4, 84]                   10,164
├─Linear: 1-7                            [4, 10]                   850
==========================================================================================
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
==========================================================================================
```
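The `Param #` column follows directly from the layer shapes: a `Conv2d` layer has `out_channels * in_channels * k * k` weights plus `out_channels` biases, and a `Linear` layer has `out_features * in_features` weights plus `out_features` biases. The sketch below reproduces the table's numbers by hand and cross-checks them against freshly constructed PyTorch layers with the same shapes (`conv2d_params` and `linear_params` are hypothetical helpers). Without torchinfo, `sum(p.numel() for p in model.parameters() if p.requires_grad)` gives the trainable total directly.

```python
import torch.nn as nn


def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights (out * in * k * k) plus one bias per output channel."""
    return out_channels * in_channels * kernel_size * kernel_size + out_channels


def linear_params(in_features, out_features):
    """Weight matrix (out * in) plus one bias per output feature."""
    return out_features * in_features + out_features


by_formula = [
    conv2d_params(3, 6, 5),          # conv1
    conv2d_params(6, 16, 5),         # conv2
    linear_params(16 * 5 * 5, 120),  # fc1
    linear_params(120, 84),          # fc2
    linear_params(84, 10),           # fc3
]
print(by_formula, sum(by_formula))  # [456, 2416, 48120, 10164, 850] 62006

# Cross-check: count the elements of real PyTorch parameter tensors with the same shapes
layers = [nn.Conv2d(3, 6, 5), nn.Conv2d(6, 16, 5),
          nn.Linear(16 * 5 * 5, 120), nn.Linear(120, 84), nn.Linear(84, 10)]
by_numel = [sum(p.numel() for p in layer.parameters()) for layer in layers]
print(by_numel == by_formula)  # True
```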

Advanced Visualization: t-SNE for Feature Visualization

t-SNE (t-Distributed Stochastic Neighbor Embedding) is a powerful technique for visualizing high-dimensional data. Let's use it to visualize features from our model:

```python
from sklearn.manifold import TSNE
import pandas as pd
import seaborn as sns

def extract_features(model, dataloader, num_batches=100):
    features = []
    labels_list = []

    # Capture the input of fc3, i.e. the 84-dim features before classification
    activations = []

    def hook_function(module, input, output):
        activations.append(input[0])

    # Register the hook on the fc3 layer
    hook = model.fc3.register_forward_hook(hook_function)

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            if i >= num_batches:
                break

            # Forward pass; the hook appends this batch's features to `activations`
            model(inputs)

            # Store the features and labels
            features.append(activations[-1])
            labels_list.append(labels)

    # Remove the hook
    hook.remove()

    # Concatenate all batches
    features = torch.cat(features)
    labels_list = torch.cat(labels_list)

    return features.cpu().numpy(), labels_list.cpu().numpy()

def visualize_tsne(features, labels, classes):
    # Apply t-SNE; perplexity (default 30) must be smaller than the number of samples
    tsne = TSNE(n_components=2, random_state=42)
    features_tsne = tsne.fit_transform(features)

    # Create a DataFrame for seaborn
    df = pd.DataFrame({
        'x': features_tsne[:, 0],
        'y': features_tsne[:, 1],
        'class': [classes[label] for label in labels]
    })

    # Plot using seaborn
    plt.figure(figsize=(10, 8))
    sns.scatterplot(x='x', y='y', hue='class', data=df, palette='tab10')
    plt.title('t-SNE visualization of features')
    plt.show()

# Extract features from the model (100 batches x 4 images = 400 samples)
features, labels = extract_features(model, trainloader)

# Visualize features using t-SNE
visualize_tsne(features, labels, trainset.classes)
```
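t-SNE plots are easy to over-read: the layout changes with `perplexity` and the random seed, and distances between clusters are not meaningful. One way to sanity-check an embedding, which is an addition of ours rather than part of the workflow above, is scikit-learn's `trustworthiness` score in [0, 1]; higher values mean fewer points appear as close neighbors in 2-D without being neighbors in the original space. This sketch uses synthetic clustered data from `make_blobs` (an assumption standing in for the extracted 84-dimensional features) so that it runs without a trained model:

```python
from sklearn.datasets import make_blobs
from sklearn.manifold import TSNE, trustworthiness

# Synthetic stand-in for extracted features: 90 points, 84 dims (like fc3's input), 3 clusters
X, y = make_blobs(n_samples=90, n_features=84, centers=3, random_state=0)

# perplexity must be smaller than the number of samples (90 here)
embedding = TSNE(n_components=2, perplexity=20, random_state=42).fit_transform(X)

# Score in [0, 1] comparing 5-nearest-neighbor sets before and after the embedding
score = trustworthiness(X, embedding, n_neighbors=5)
print(embedding.shape, round(score, 3))
```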

Visualizing Gradient Flow

Tracking gradients during training can help identify problems like vanishing or exploding gradients:

```python
def plot_grad_flow(named_parameters):
    """Plot the mean and max absolute gradient of each weight tensor in the network."""
    ave_grads = []
    max_grads = []
    layers = []

    for n, p in named_parameters:
        # Skip biases and parameters that did not receive a gradient
        if p.requires_grad and "bias" not in n and p.grad is not None:
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())

    plt.figure(figsize=(10, 8))
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=-0.5, right=len(ave_grads) - 0.5)  # Show the first and last bars in full
    plt.ylim(bottom=0, top=0.02)  # Adjust this based on your gradients
    plt.xlabel("Layers")
    plt.ylabel("Gradient magnitude")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([plt.Line2D([0], [0], color="c", lw=4),
                plt.Line2D([0], [0], color="b", lw=4),
                plt.Line2D([0], [0], color="k", lw=4)],
               ['max-gradient', 'mean-gradient', 'zero-gradient'])
    plt.tight_layout()
    plt.show()

# Inspect gradient flow for a freshly initialized model on a single batch
model = SimpleNN()
criterion = nn.CrossEntropyLoss()

# Get a batch
dataiter = iter(trainloader)
inputs, labels = next(dataiter)

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)

# Backward pass populates p.grad for every parameter
loss.backward()

# Plot gradients
plot_grad_flow(model.named_parameters())
```
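Reading the bar chart by eye works for small models; for larger ones it can help to flag suspicious layers automatically. This is a minimal sketch, not an established method: `flag_gradients` and its thresholds (`low=1e-7`, `high=1e2`) are hypothetical values you would tune for your model, and the gradients are set by hand on a toy model so that each label appears once.

```python
import torch
import torch.nn as nn


def flag_gradients(named_parameters, low=1e-7, high=1e2):
    """Label each weight tensor 'vanishing', 'exploding' or 'ok' by its mean |grad|.

    The thresholds are hypothetical defaults; tune them for your model.
    """
    report = {}
    for name, p in named_parameters:
        if p.grad is None or "bias" in name:  # Same filtering as plot_grad_flow
            continue
        mean_abs = p.grad.abs().mean().item()
        if mean_abs < low:
            report[name] = 'vanishing'
        elif mean_abs > high:
            report[name] = 'exploding'
        else:
            report[name] = 'ok'
    return report


# Toy model with hand-set gradients (made-up values) to show each label
toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 2))
for layer, g in zip(toy, [1e-9, 1e-2, 1e3]):
    for p in layer.parameters():
        p.grad = torch.full(p.shape, g)

report = flag_gradients(toy.named_parameters())
print(report)  # {'0.weight': 'vanishing', '1.weight': 'ok', '2.weight': 'exploding'}
```

In a real run, `flag_gradients(model.named_parameters())` would be called right after `loss.backward()`, at the same point where `plot_grad_flow` is called above.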

Summary

Visualization is an essential tool for understanding, debugging, and improving your PyTorch models. In this tutorial, we've covered:

  1. TensorBoard Integration for tracking metrics, visualizing model graphs, and monitoring training
  2. Matplotlib Integration for custom plots and visualizations
  3. Feature Map Visualization to understand what your convolutional layers are learning
  4. t-SNE Visualization for high-dimensional feature analysis
  5. Gradient Flow Visualization to detect training issues

By incorporating these visualization techniques into your deep learning workflow, you'll gain deeper insights into your models and be able to debug issues more effectively.


Exercises

  1. Add TensorBoard visualization to a model you've built previously. Track at least three different metrics during training.
  2. Implement a function that extracts and visualizes feature maps from all convolutional layers in a network.
  3. Create a visualization that compares the feature maps of a well-trained model with those of an untrained model.
  4. Use t-SNE to visualize the features of your model at different epochs during training. Can you see how the feature space organizes as training progresses?
  5. Build a dashboard that combines TensorBoard, Matplotlib, and custom visualizations for a comprehensive view of your model's performance and behavior.

With these tools and exercises, you'll be well-equipped to gain deeper insights into your PyTorch models and debug any issues that arise during development.


