PyTorch Visualization Tools
Introduction
Visualizing your models, data, and training processes is a crucial aspect of debugging and understanding deep learning systems. Without proper visualization, deep learning can often feel like a mysterious black box. PyTorch includes and works with several visualization tools that help you gain insights into your models and training progress.
In this tutorial, we'll explore various visualization techniques in PyTorch that will help you:
- Visualize model architectures
- Track training metrics
- Inspect tensors and data
- Visualize feature maps and activations
- Debug gradient flow and parameter updates
Let's dive in and learn how to make your PyTorch projects more transparent and easier to debug!
TensorBoard Integration with PyTorch
TensorBoard is one of the most powerful visualization tools available for deep learning, and PyTorch integrates with it through the torch.utils.tensorboard module.
Setting Up TensorBoard
First, let's install the required packages if you haven't already:
```bash
pip install tensorboard
```
Now, let's see how to use TensorBoard with PyTorch:
```python
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np

# Create a SummaryWriter instance
writer = SummaryWriter('runs/my_experiment')

# Define a simple neural network for demonstration
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize the model
model = SimpleNN()

# Get some example data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Add model graph to tensorboard
writer.add_graph(model, images)

# Close the writer
writer.close()
```
After running this code, you can launch TensorBoard with:
```bash
tensorboard --logdir=runs
```
Then open http://localhost:6006 in your browser to see your model graph.
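If you want to confirm that values actually reached the log directory (for example in a test or on a headless server), you can read the event files back in Python. The sketch below assumes the `EventAccumulator` class from the `tensorboard` package (`tensorboard.backend.event_processing.event_accumulator`), which is widely used but is an internal module rather than a stable public API; the tag `demo/loss` and the temporary directory are arbitrary choices.

```python
import tempfile

from torch.utils.tensorboard import SummaryWriter
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Write a few scalar values to a throwaway log directory
logdir = tempfile.mkdtemp()
writer = SummaryWriter(logdir)
for step, value in enumerate([1.0, 0.5, 0.25]):
    writer.add_scalar('demo/loss', value, step)
writer.close()  # flush events to disk before reading them back

# Load the event file(s) and inspect the logged scalars
acc = EventAccumulator(logdir)
acc.Reload()
events = acc.Scalars('demo/loss')
steps = [e.step for e in events]
values = [e.value for e in events]
print(acc.Tags()['scalars'])  # expected to include 'demo/loss'
print(steps, values)
```

Each returned event also carries a `wall_time`, which can be handy for measuring how long each step took.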
Tracking Training Metrics
TensorBoard really shines when tracking metrics over time during training. Here's how to log loss and accuracy during training:
```python
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training loop
for epoch in range(2):  # Just 2 epochs for demonstration
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Log statistics
        running_loss += loss.item()
        if i % 100 == 99:  # Every 100 mini-batches
            # Global step keeps increasing across epochs
            global_step = epoch * len(trainloader) + i
            # Log the average loss over the last 100 mini-batches
            writer.add_scalar('training loss', running_loss / 100, global_step)
            # Accuracy on the current mini-batch only (a noisy estimate)
            predicted = outputs.argmax(dim=1)
            accuracy = (predicted == labels).sum().item() / labels.size(0)
            writer.add_scalar('training accuracy', accuracy, global_step)
            running_loss = 0.0

writer.close()
```
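The logging pattern above relies on two pieces of bookkeeping: a global step that keeps increasing across epochs (`epoch * len(trainloader) + i`) and a running total that is averaged, emitted every 100 batches, and reset. The pure-Python sketch below isolates that logic so it can be checked without a dataset. `WindowedLogger` is a hypothetical helper, and `log_fn` stands in for a call such as `writer.add_scalar`.

```python
class WindowedLogger:
    """Averages a metric over fixed-size windows of batches (hypothetical helper)."""

    def __init__(self, window, log_fn):
        self.window = window
        self.log_fn = log_fn  # e.g. lambda v, s: writer.add_scalar('training loss', v, s)
        self.total = 0.0

    def update(self, value, epoch, batch_idx, batches_per_epoch):
        if batch_idx == 0:  # new epoch: reset, like `running_loss = 0.0` in the loop
            self.total = 0.0
        self.total += value
        # Same condition as `if i % 100 == 99` when window=100
        if batch_idx % self.window == self.window - 1:
            # Global step keeps increasing across epochs
            global_step = epoch * batches_per_epoch + batch_idx
            self.log_fn(self.total / self.window, global_step)
            self.total = 0.0


# Usage with fake per-batch losses: 2 epochs of 6 batches, window of 3
logged = []
logger = WindowedLogger(window=3, log_fn=lambda v, s: logged.append((s, v)))
for epoch in range(2):
    for i in range(6):
        logger.update(float(i), epoch, i, batches_per_epoch=6)

print(logged)  # [(2, 1.0), (5, 4.0), (8, 1.0), (11, 4.0)]
```

Because the step never restarts at zero, the curves from different epochs line up one after another on TensorBoard's x-axis instead of overlapping.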
Visualizing Images
TensorBoard also allows you to visualize images, which can be helpful for debugging data augmentation or model outputs:
```python
# Get a batch of images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# add_image expects float values in [0, 1], so undo Normalize((0.5, ...), (0.5, ...))
images = images / 2 + 0.5

# Create a grid of images and add it to TensorBoard
img_grid = torchvision.utils.make_grid(images)
writer.add_image('Four CIFAR Images', img_grid)

# You can also log each image under its own tag
for i in range(4):
    writer.add_image(f'CIFAR Image #{i}', images[i], 0)

writer.close()
```
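Because the transform in this tutorial maps pixels from [0, 1] to [-1, 1] with `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, images need the inverse mapping, x * std + mean, before display; `add_image` and `plt.imshow` both expect floats in [0, 1]. The helper below is a hypothetical general version for any per-channel mean and std (for example, ImageNet-style statistics). It is a sketch, not a torchvision API.

```python
import torch

def unnormalize(img, mean, std):
    """Invert transforms.Normalize(mean, std) for a (C, H, W) or (N, C, H, W) tensor (hypothetical helper)."""
    mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    # Normalize computes (x - mean) / std, so the inverse is x * std + mean
    return (img * std + mean).clamp(0.0, 1.0)

# Round trip on a random image in [0, 1]
original = torch.rand(3, 8, 8)
mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
normalized = (original - torch.tensor(mean).view(-1, 1, 1)) / torch.tensor(std).view(-1, 1, 1)
restored = unnormalize(normalized, mean, std)
print(torch.allclose(restored, original, atol=1e-6))  # True
```

The `view(-1, 1, 1)` reshaping lets the per-channel statistics broadcast over both single images and batches.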
Visualizing Histograms and Distributions
TensorBoard can show histograms of model parameters and gradients:
```python
for epoch in range(2):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # param.grad is None until backward() has run, so compute gradients first
        loss.backward()

        # Log histograms of model parameters and their gradients every 10 mini-batches
        if i % 10 == 0:
            global_step = epoch * len(trainloader) + i
            for name, param in model.named_parameters():
                writer.add_histogram(f'Parameters/{name}', param, global_step)
                writer.add_histogram(f'Gradients/{name}', param.grad, global_step)

        optimizer.step()

        if i == 100:  # Just log a few batches to avoid overwhelming TensorBoard
            break

writer.close()
```
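Gradient histograms depend on call order: `param.grad` is `None` until the first `backward()`, and `optimizer.zero_grad()` clears it again (in PyTorch 2.x it sets gradients to `None` by default). The small check below uses a hypothetical one-layer model to show when gradients are available for logging.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)  # hypothetical tiny model
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

# Before any backward pass there is nothing to log
grad_before = layer.weight.grad

loss = layer(torch.randn(8, 4)).pow(2).mean()
loss.backward()
grad_after = layer.weight.grad  # now a tensor with the same shape as the weight

optimizer.zero_grad()  # set_to_none=True is the default in PyTorch 2.x
grad_after_zero = layer.weight.grad

print(grad_before, tuple(grad_after.shape), grad_after_zero)
```

So the safe place to log gradient histograms is between `loss.backward()` and the next `optimizer.zero_grad()`.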
Matplotlib Integration
While TensorBoard is great for monitoring training as it runs, sometimes you need custom plots for presentations or papers. PyTorch tensors convert to NumPy arrays with .numpy(), so they plug directly into Matplotlib:
```python
import matplotlib.pyplot as plt

# Get a single image and its label
image, label = trainset[0]
image = image / 2 + 0.5  # Unnormalize the image

# Convert tensor to numpy for plotting
npimg = image.numpy()

# Plot the image
plt.figure(figsize=(10, 8))
plt.imshow(np.transpose(npimg, (1, 2, 0)))  # Convert from (C, H, W) to (H, W, C)
plt.title(f'Label: {trainset.classes[label]}')
plt.show()
```
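The `np.transpose(npimg, (1, 2, 0))` call is needed because PyTorch stores images channels-first, (C, H, W), while `plt.imshow` expects channels-last, (H, W, C). An equivalent approach is to permute the tensor before converting it; the check below uses a random tensor in place of a CIFAR-10 image.

```python
import numpy as np
import torch

image = torch.rand(3, 32, 32)  # stand-in for an unnormalized CIFAR-10 image (C, H, W)

hwc_numpy = np.transpose(image.numpy(), (1, 2, 0))  # approach used above
hwc_torch = image.permute(1, 2, 0).numpy()          # permute first, then convert

print(hwc_numpy.shape)  # (32, 32, 3)
print(np.array_equal(hwc_numpy, hwc_torch))  # True
```

Either way, pixel (row h, column w) of channel c ends up at index `[h, w, c]`.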
Visualizing Feature Maps
Understanding what your network "sees" can be invaluable for debugging. Let's visualize the feature maps of a convolutional layer:
```python
def visualize_feature_maps(model, input_tensor):
    # Register a hook to capture the output of conv1
    feature_maps = []

    def hook_function(module, input, output):
        feature_maps.append(output)

    # Register the hook
    hook = model.conv1.register_forward_hook(hook_function)

    # Forward pass
    with torch.no_grad():
        output = model(input_tensor)

    # Remove the hook
    hook.remove()

    # Get the feature maps
    feature_map = feature_maps[0]  # First layer feature maps

    # Plot the feature maps
    fig, axs = plt.subplots(1, feature_map.size(1), figsize=(15, 5))
    for i in range(feature_map.size(1)):  # Loop over each channel
        # Extract the i-th feature map
        fm = feature_map[0, i].detach().cpu().numpy()
        axs[i].imshow(fm, cmap='viridis')
        axs[i].axis('off')
    plt.tight_layout()
    plt.show()

# Get a single image and add batch dimension
image, _ = trainset[0]
input_tensor = image.unsqueeze(0)  # Add batch dimension

# Visualize feature maps
visualize_feature_maps(model, input_tensor)
```
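The same hook mechanism extends to every convolutional layer at once: loop over `model.named_modules()`, register a forward hook on each `nn.Conv2d`, and store the outputs in a dictionary keyed by layer name. The sketch below is one way to do it; the helper name `capture_conv_outputs` is hypothetical, and it uses a small stand-in network with the same convolution stack as `SimpleNN` so it runs on its own.

```python
import torch
import torch.nn as nn

def capture_conv_outputs(model, input_tensor):
    """Run one forward pass and return {layer_name: output} for every Conv2d (hypothetical helper)."""
    outputs = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # The default argument binds the current name inside the closure
            def hook(mod, inp, out, name=name):
                outputs[name] = out.detach()
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(input_tensor)
    finally:
        # Always remove hooks so later forward passes are unaffected
        for h in handles:
            h.remove()
    return outputs

# Stand-in network with the same convolution stack as SimpleNN
net = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
)
maps = capture_conv_outputs(net, torch.randn(1, 3, 32, 32))
for name, fm in maps.items():
    print(name, tuple(fm.shape))
# 0 (1, 6, 28, 28)
# 3 (1, 16, 10, 10)
```

Each entry can then be plotted channel by channel with the same subplot loop used in `visualize_feature_maps`.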
Loss Curve Visualization
Tracking loss over time is one of the most common visualizations in deep learning:
```python
def train_and_plot(model, trainloader, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Lists to store metrics
    train_losses = []
    train_accuracies = []

    for epoch in range(epochs):
        running_loss = 0.0
        running_acc = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            running_acc += (predicted == labels).sum().item() / labels.size(0)

            # Record metrics every 100 batches
            if i % 100 == 99:
                avg_loss = running_loss / 100
                avg_acc = running_acc / 100
                train_losses.append(avg_loss)
                train_accuracies.append(avg_acc)
                running_loss = 0.0
                running_acc = 0.0
                print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {avg_loss:.3f}, accuracy: {avg_acc:.3f}')

    # Plot loss curve
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses)
    plt.title('Training Loss')
    plt.xlabel('Steps (hundreds)')
    plt.ylabel('Loss')

    # Plot accuracy curve
    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies)
    plt.title('Training Accuracy')
    plt.xlabel('Steps (hundreds)')
    plt.ylabel('Accuracy')

    plt.tight_layout()
    plt.show()

    return model

# Train the model and visualize the results
model = train_and_plot(SimpleNN(), trainloader, epochs=2)
```
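Per-batch loss values are noisy, so curves are often smoothed before plotting. TensorBoard's smoothing slider applies an exponential moving average; the function below is a plain EMA sketch in that spirit, not a copy of TensorBoard's exact implementation (which additionally corrects for start-up bias). The name `ema_smooth` and the default weight are assumptions for illustration.

```python
def ema_smooth(values, weight=0.6):
    """Exponential moving average: s_t = weight * s_{t-1} + (1 - weight) * x_t."""
    if not 0.0 <= weight < 1.0:
        raise ValueError("weight must be in [0, 1)")
    smoothed = []
    # Start from the first value so the curve is not pulled toward 0 at the beginning
    last = values[0] if values else 0.0
    for v in values:
        last = weight * last + (1 - weight) * v
        smoothed.append(last)
    return smoothed

print(ema_smooth([1.0, 0.0, 1.0], weight=0.5))  # [1.0, 0.5, 0.75]
```

For example, `plt.plot(ema_smooth(train_losses, weight=0.8))` would draw a smoother version of the loss curve; higher weights smooth more but react more slowly to real changes.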
PyTorch Built-in Visualization Tools
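The PyTorch ecosystem also covers quick checks of tensors and architectures: torchvision ships utilities for turning image tensors into grids, and third-party packages can summarize a model's layers.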
Visualizing Tensors with torchvision.utils
```python
def visualize_batch(dataloader):
    dataiter = iter(dataloader)
    images, labels = next(dataiter)

    # Create a grid from the images
    img_grid = torchvision.utils.make_grid(images, nrow=4, padding=10)
    img_grid = img_grid / 2 + 0.5  # Unnormalize

    # Convert to numpy and transpose
    npimg = img_grid.numpy()
    plt.figure(figsize=(15, 15))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    plt.show()

    # Print labels (class names come from the loader's own dataset)
    classes = dataloader.dataset.classes
    print('Labels:', [classes[label] for label in labels])

# Visualize a batch of images
visualize_batch(trainloader)
```
Visualizing Model Summary
For a quick overview of your model, you can use the third-party library torchinfo (formerly torch-summary):
```bash
pip install torchinfo
```

```python
from torchinfo import summary

# Display model summary
summary(model, input_size=(4, 3, 32, 32))
```
Output:
```
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleNN                                 [4, 10]                   --
├─Conv2d: 1-1                            [4, 6, 28, 28]            456
├─MaxPool2d: 1-2                         [4, 6, 14, 14]            --
├─Conv2d: 1-3                            [4, 16, 10, 10]           2,416
├─MaxPool2d: 1-4                         [4, 16, 5, 5]             --
├─Linear: 1-5                            [4, 120]                  48,120
├─Linear: 1-6                            [4, 84]                   10,164
├─Linear: 1-7                            [4, 10]                   850
==========================================================================================
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
==========================================================================================
```
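The numbers in this summary can be reproduced by hand, which is a useful sanity check when designing a network. A `Conv2d` layer (with the default `groups=1` and bias) has `out_channels * in_channels * k * k` weights plus `out_channels` biases, a `Linear` layer has `out_features * in_features + out_features`, and with no padding and stride 1 each 5x5 convolution shrinks the spatial size by 4 (from 32 to 28, and from 14 to 10) while each 2x2 pooling halves it. The sketch below repeats the `SimpleNN` definition so it is self-contained; `conv_out` and `layer_params` are hypothetical helpers.

```python
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

def layer_params(module):
    """Parameter count from the formulas in the text (assumes groups=1 and bias=True)."""
    if isinstance(module, nn.Conv2d):
        k = module.kernel_size[0] * module.kernel_size[1]
        return module.out_channels * module.in_channels * k + module.out_channels
    if isinstance(module, nn.Linear):
        return module.out_features * module.in_features + module.out_features
    return 0  # pooling layers and the container itself have no own parameters

model = SimpleNN()
by_formula = sum(layer_params(m) for m in model.modules())
by_pytorch = sum(p.numel() for p in model.parameters())

# 32 -> conv 28 -> pool 14 -> conv 10 -> pool 5, hence fc1's 16 * 5 * 5 inputs
size = conv_out(conv_out(conv_out(conv_out(32, 5), 2, 2), 5), 2, 2)
print(by_formula, by_pytorch, size)  # 62006 62006 5
```

If the two counts ever disagree for your own model, some layer type is not covered by the formulas, which is exactly what a summary table helps you spot.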
Advanced Visualization: t-SNE for Feature Visualization
t-SNE (t-Distributed Stochastic Neighbor Embedding) is a powerful technique for visualizing high-dimensional data. Let's use it to visualize features from our model:
```python
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE

def extract_features(model, dataloader, num_batches=100):
    features = []
    labels_list = []

    # Register a hook to capture the input to fc3, i.e. the 84-dimensional
    # features from the last layer before classification
    activations = []

    def hook_function(module, input, output):
        activations.append(input[0])

    hook = model.fc3.register_forward_hook(hook_function)

    model.eval()
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            if i >= num_batches:
                break
            # Forward pass (the hook stores the features)
            model(inputs)
            features.append(activations[-1])
            labels_list.append(labels)

    # Remove the hook
    hook.remove()

    # Concatenate all batches
    features = torch.cat(features)
    labels_list = torch.cat(labels_list)
    return features.cpu().numpy(), labels_list.cpu().numpy()

def visualize_tsne(features, labels, classes):
    # t-SNE requires perplexity < number of samples (scikit-learn's default is 30)
    perplexity = min(30.0, len(features) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    features_tsne = tsne.fit_transform(features)

    # Create a DataFrame for seaborn
    df = pd.DataFrame({
        'x': features_tsne[:, 0],
        'y': features_tsne[:, 1],
        'class': [classes[label] for label in labels]
    })

    # Plot using seaborn
    plt.figure(figsize=(10, 8))
    sns.scatterplot(x='x', y='y', hue='class', data=df, palette='tab10')
    plt.title('t-SNE visualization of features')
    plt.show()

# Extract features from the model (100 batches of 4 images = 400 points)
features, labels = extract_features(model, trainloader)

# Visualize features using t-SNE
visualize_tsne(features, labels, trainset.classes)
```
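For feature vectors much wider than the 84-dimensional `fc3` inputs (for example flattened convolutional activations), scikit-learn's TSNE documentation recommends first reducing the dimensionality, for instance with PCA to around 50 components. The sketch below shows that two-step pattern on synthetic data: two random clusters in 400 dimensions stand in for extracted features, and all sizes are arbitrary choices.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
# Synthetic stand-in for extracted features: two clusters of 60 points in 400 dimensions
features = np.vstack([
    rng.normal(0.0, 1.0, size=(60, 400)),
    rng.normal(3.0, 1.0, size=(60, 400)),
])
labels = np.array([0] * 60 + [1] * 60)  # for coloring a scatter plot, as in visualize_tsne

# Step 1: PCA to at most 50 dimensions (cannot exceed n_samples or n_features)
n_components = min(50, *features.shape)
reduced = PCA(n_components=n_components, random_state=0).fit_transform(features)

# Step 2: t-SNE on the reduced features; perplexity must stay below n_samples
tsne = TSNE(n_components=2, perplexity=min(30.0, len(reduced) - 1), random_state=42)
embedding = tsne.fit_transform(reduced)
print(reduced.shape, embedding.shape)  # (120, 50) (120, 2)
```

The PCA step mainly cuts runtime and suppresses noise; the 2D embedding is still produced by t-SNE.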
Visualizing Gradient Flow
Tracking gradients during training can help identify problems like vanishing or exploding gradients:
```python
def plot_grad_flow(named_parameters):
    """Plots the mean and max absolute gradient of each weight tensor in the network.

    Call after loss.backward() and before optimizer.zero_grad().
    """
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        # Skip biases, frozen parameters, and parameters that received no gradient
        if p.requires_grad and "bias" not in n and p.grad is not None:
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())
    plt.figure(figsize=(10, 8))
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=0, top=0.02)  # Adjust this based on your gradients
    plt.xlabel("Layers")
    plt.ylabel("Average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([plt.Line2D([0], [0], color="c", lw=4),
                plt.Line2D([0], [0], color="b", lw=4),
                plt.Line2D([0], [0], color="k", lw=4)],
               ['max-gradient', 'mean-gradient', 'zero-gradient'])
    plt.tight_layout()
    plt.show()

# Training with gradient visualization
model = SimpleNN()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# Get a batch
dataiter = iter(trainloader)
inputs, labels = next(dataiter)

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)

# Backward pass
loss.backward()

# Plot gradients
plot_grad_flow(model.named_parameters())
```
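When a plot is inconvenient (for example in training logs or automated checks), the same statistics can be summarized numerically. The sketch below computes the mean absolute gradient per weight tensor and flags values outside a chosen band. The helper name `grad_report` and the thresholds 1e-7 and 1e2 are illustrative assumptions, not standard values; suitable bounds depend on the model and loss scale.

```python
import torch
import torch.nn as nn

def grad_report(model, low=1e-7, high=1e2):
    """Return {param_name: (mean_abs_grad, status)} for weights that have gradients (hypothetical helper)."""
    report = {}
    for name, p in model.named_parameters():
        if p.grad is None or "bias" in name:
            continue
        mean_abs = p.grad.abs().mean().item()
        if mean_abs < low:
            status = "possibly vanishing"
        elif mean_abs > high:
            status = "possibly exploding"
        else:
            status = "ok"
        report[name] = (mean_abs, status)
    return report

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))  # small stand-in model
loss = nn.CrossEntropyLoss()(net(torch.randn(16, 10)), torch.randint(0, 3, (16,)))
loss.backward()

for name, (value, status) in grad_report(net).items():
    print(f"{name:10s} {value:.2e} {status}")
```

Calling it every few hundred steps and printing only non-"ok" entries keeps the output short while still surfacing gradient problems early.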
Summary
Visualization is an essential tool for understanding, debugging, and improving your PyTorch models. In this tutorial, we've covered:
- TensorBoard Integration for tracking metrics, visualizing model graphs, and monitoring training
- Matplotlib Integration for custom plots and visualizations
- Feature Map Visualization to understand what your convolutional layers are learning
- t-SNE Visualization for high-dimensional feature analysis
- Gradient Flow Visualization to detect training issues
By incorporating these visualization techniques into your deep learning workflow, you'll gain deeper insights into your models and be able to debug issues more effectively.
Exercises
- Add TensorBoard visualization to a model you've built previously. Track at least three different metrics during training.
- Implement a function that extracts and visualizes feature maps from all convolutional layers in a network.
- Create a visualization that compares the feature maps of a well-trained model with those of an untrained model.
- Use t-SNE to visualize the features of your model at different epochs during training. Can you see how the feature space organizes as training progresses?
- Build a dashboard that combines TensorBoard, Matplotlib, and custom visualizations for a comprehensive view of your model's performance and behavior.
With these tools and exercises, you'll be well-equipped to gain deeper insights into your PyTorch models and debug any issues that arise during development.