PyTorch Captum

Introduction

PyTorch Captum (meaning "comprehension" in Latin) is an open-source, extensible library for model interpretability built on top of PyTorch. In the era of increasingly complex machine learning models, understanding why a model makes certain decisions is becoming as important as the accuracy of those decisions. Captum provides state-of-the-art algorithms to help you understand which features your model uses for predictions and how these features interact.

Whether you're a researcher, data scientist, or ML engineer, Captum can help you:

  • Debug and improve your models
  • Ensure your models are making decisions for the right reasons
  • Comply with regulations that require explainable AI
  • Build trust with stakeholders who need to understand model decisions

Getting Started with Captum

Installation

You can install Captum using pip:

bash
pip install captum

Make sure you have PyTorch installed before installing Captum.
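
To confirm that both packages are visible to your interpreter, a small check along these lines can help. It is only a sketch using the standard library; the helper name `installed_version` is made up for this example and is not part of Captum or PyTorch.

```python
import importlib.util
from importlib import metadata


def installed_version(package: str):
    """Return the installed version of `package`, or None if it cannot be imported."""
    if importlib.util.find_spec(package) is None:
        return None
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        # Importable, but no distribution metadata (e.g. a standard-library module).
        return "unknown"


for name in ("torch", "captum"):
    version = installed_version(name)
    print(f"{name}: {version if version else 'not installed'}")
```

If either line prints "not installed", install that package before running the examples below.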

Basic Concepts

Before diving into code examples, let's understand some basic concepts in model interpretability:

  1. Attribution: Measuring the contribution of each input feature to the output prediction.
  2. Feature Importance: Identifying which features have the most influence on the model's predictions.
  3. Saliency Maps: Visual representations showing which parts of the input (like pixels in an image) contribute most to the output.
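
To make these terms concrete, here is a minimal, library-free sketch on a toy function. The function `toy_model`, the input values, and the finite-difference gradient are all assumptions chosen for illustration; Captum computes gradients with autograd instead.

```python
def toy_model(x):
    # Hypothetical "model": a fixed nonlinear function of three features.
    return x[0] ** 2 + 3.0 * x[1] - x[2]


def numerical_gradient(f, x, eps=1e-6):
    """Central-difference estimate of df/dx_i for every feature i."""
    grads = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


x = [2.0, 1.0, 4.0]
grads = numerical_gradient(toy_model, x)

# Saliency: how sensitive the output is to each feature (gradient magnitude).
saliency = [abs(g) for g in grads]

# A simple attribution: gradient times input, a signed contribution per feature.
attribution = [g * xi for g, xi in zip(grads, x)]

# Feature importance: rank features by the size of their attribution.
ranking = sorted(range(len(x)), key=lambda i: -abs(attribution[i]))

print("saliency:", [round(s, 3) for s in saliency])
print("attribution:", [round(a, 3) for a in attribution])
print("ranking:", ranking)
```

Note that the two views can disagree: the third feature has the smallest gradient but, because its value is large, the second-largest attribution.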

Basic Example: Integrated Gradients

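Let's start with a simple example using one of Captum's most popular attribution methods: Integrated Gradients. This method attributes a network's prediction to its input features by accumulating gradients along a straight-line path from a baseline input (all zeros by default) to the actual input.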

First, we'll set up a simple model and prepare our data:

python
import torch
import torch.nn as nn
import captum
from captum.attr import IntegratedGradients

# Create a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)

    def forward(self, x):
        return self.linear(x)

# Initialize the model
model = SimpleModel()
model.linear.weight.data = torch.tensor([[1.0, -2.0, 3.0]])
model.linear.bias.data = torch.tensor([0.0])

# Create input data
input_data = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)

Now, let's apply Integrated Gradients to understand the model's predictions:

python
# Initialize the attribution algorithm
ig = IntegratedGradients(model)

# Compute attributions
attributions, approximation_error = ig.attribute(input_data,
                                                 target=0,
                                                 return_convergence_delta=True)

print("Input:", input_data)
print("Attributions:", attributions)
print("Approximation error:", approximation_error)

Output:

Input: tensor([[1., 2., 3.]], requires_grad=True)
Attributions: tensor([[1., -4., 9.]], grad_fn=<MulBackward0>)
Approximation error: tensor([0.])

The attributions tell us how much each input feature contributed to the model's output. In this case:

  • The first feature (1.0) contributed 1.0 to the output
  • The second feature (2.0) contributed -4.0 to the output
  • The third feature (3.0) contributed 9.0 to the output

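This aligns with our model weights (1.0, -2.0, 3.0): for a linear model with the default all-zero baseline, each attribution is exactly weight × input (1.0 × 1.0, -2.0 × 2.0, 3.0 × 3.0). The attributions also sum to the model output of 6.0, which is why the approximation error is zero.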
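
The same numbers can be reproduced by hand. The sketch below approximates the Integrated Gradients path integral with a midpoint Riemann sum in plain Python. It is an illustration under stated assumptions, not Captum's implementation: the function names are made up, and Captum's default integration rule may differ from the midpoint rule used here.

```python
def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    total = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        # Point on the straight line from the baseline to the input.
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        total = [t + g for t, g in zip(total, grad_fn(point))]
    avg_grad = [t / n_steps for t in total]
    # Scale the average gradient by the distance travelled in each feature.
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


weights = [1.0, -2.0, 3.0]  # same weights as SimpleModel above


def forward(x):
    return sum(w * xi for w, xi in zip(weights, x))


def grad_fn(x):
    # For a linear model the gradient is the weight vector at every point.
    return list(weights)


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]

attr = integrated_gradients(grad_fn, x, baseline)
delta = sum(attr) - (forward(x) - forward(baseline))

print(attr)   # [1.0, -4.0, 9.0]
print(delta)  # 0.0
```

The `delta` value plays the role of the convergence delta returned by Captum: it measures how far the attributions are from summing to the difference between the output at the input and the output at the baseline.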

Visualizing Attributions in Computer Vision

Captum is particularly useful for understanding image classification models. Let's see how to visualize attributions for a pre-trained ResNet model:

python
import torch
import torchvision
import torchvision.transforms as transforms
from captum.attr import GradientShap
from captum.attr import visualization as viz
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

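# Load pre-trained model (torchvision >= 0.13; older releases use pretrained=True instead)
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
model.eval()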

# Load and preprocess image
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Replace with your image path; convert to RGB so grayscale or RGBA files also yield 3 channels
img = Image.open('dog.jpg').convert('RGB')
input_tensor = transform(img).unsqueeze(0)
input_tensor.requires_grad = True

# Create a baseline (black image)
baseline = torch.zeros_like(input_tensor)

# Initialize GradientShap
gradient_shap = GradientShap(model)

# Calculate attributions
attributions = gradient_shap.attribute(input_tensor,
                                       baselines=baseline,
                                       target=232)  # 232 is the class index for "border collie"

# Visualize the results
original_image = np.transpose(input_tensor.squeeze().cpu().detach().numpy(), (1, 2, 0))
original_image = (original_image - original_image.min()) / (original_image.max() - original_image.min())

# Convert attributions to numpy array for visualization
attributions = attributions.squeeze().cpu().detach().numpy()
attributions = np.transpose(attributions, (1, 2, 0))

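# Visualize. visualize_image_attr returns a (figure, axes) pair rather than an image,
# so we hand it an existing axes to draw on instead of passing its result to imshow.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].imshow(original_image)
axes[0].set_title('Original Image')

viz.visualize_image_attr(attributions,
                         original_image,
                         method="heat_map",
                         sign="all",
                         show_colorbar=True,
                         title='Attribution Heatmap',
                         plt_fig_axis=(fig, axes[1]),
                         use_pyplot=False)

plt.tight_layout()
plt.show()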

This example creates a heatmap showing which regions of the image were most important for the model's classification decision.
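
For intuition about what a signed heat map encodes, the sketch below collapses a per-channel attribution array into one value per pixel and rescales it to the range [-1, 1]. This is a simplified stand-in, not Captum's exact normalization (which also handles outliers); the helper name `to_heatmap` and the random input are assumptions for illustration.

```python
import numpy as np


def to_heatmap(attr_hwc: np.ndarray) -> np.ndarray:
    """Collapse an (H, W, C) attribution array to a signed (H, W) map in [-1, 1]."""
    summed = attr_hwc.sum(axis=2)      # combine the colour channels per pixel
    scale = np.abs(summed).max()       # largest magnitude sets the colour range
    if scale == 0:
        return np.zeros_like(summed)
    return summed / scale


# Stand-in for real attributions: random values for a 4x4 RGB image.
rng = np.random.default_rng(0)
fake_attr = rng.normal(size=(4, 4, 3))

heatmap = to_heatmap(fake_attr)
print(heatmap.shape)
print(float(np.abs(heatmap).max()))
```

Positive values mark pixels that pushed the score for the target class up, negative values mark pixels that pushed it down.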

Layer Attribution Methods

Sometimes, you want to understand the importance of specific layers or neurons. Captum provides methods for this as well:

python
from captum.attr import LayerGradCam

# Define the layer we want to analyze
layer = model.layer4[1].conv2

# Initialize LayerGradCam
layer_gc = LayerGradCam(model, layer)

# Calculate attributions
layer_attributions = layer_gc.attribute(input_tensor, target=232)

# Process the attributions for visualization.
# GradCAM already sums over channels, so squeezing the (1, 1, H, W) output leaves an (H, W) map.
layer_attributions = layer_attributions.squeeze().cpu().detach().numpy()

# Upsample attributions to the input size for visualization
from scipy.ndimage import zoom
upsampled_attrs = zoom(layer_attributions,
                       (input_tensor.shape[2] / layer_attributions.shape[0],
                        input_tensor.shape[3] / layer_attributions.shape[1]))

# Visualize. As before, visualize_image_attr draws on the axes we pass in.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].imshow(original_image)
axes[0].set_title('Original Image')

viz.visualize_image_attr(upsampled_attrs[:, :, None],
                         original_image,
                         method="heat_map",
                         sign="all",
                         show_colorbar=True,
                         title='Layer GradCAM',
                         plt_fig_axis=(fig, axes[1]),
                         use_pyplot=False)

plt.tight_layout()
plt.show()
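
To see what a Grad-CAM style layer attribution computes, the following NumPy sketch works on made-up activations and gradients for a single image. It follows the usual Grad-CAM recipe (average each channel's gradient, then take a weighted sum of the activation maps); whether the final ReLU is applied can differ between implementations, so it is exposed here as a hypothetical `apply_relu` flag.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray, apply_relu: bool = True) -> np.ndarray:
    """Grad-CAM style map from (C, H, W) layer activations and output gradients."""
    # One weight per channel: the spatial average of that channel's gradient.
    channel_weights = gradients.mean(axis=(1, 2))              # shape (C,)
    # Weighted sum of the activation maps over the channel axis.
    cam = np.tensordot(channel_weights, activations, axes=1)   # shape (H, W)
    return np.maximum(cam, 0.0) if apply_relu else cam


# Illustrative values only: two channels on a 2x2 feature map.
activations = np.array([[[1.0, 2.0], [3.0, 4.0]],
                        [[1.0, 1.0], [1.0, 1.0]]])
gradients = np.array([[[1.0, 1.0], [1.0, 1.0]],
                      [[-2.0, -2.0], [-2.0, -2.0]]])

cam = grad_cam(activations, gradients)
print(cam)
```

The result has one value per spatial location and no channel axis, which is why the layer attributions above only need to be squeezed and upsampled before plotting.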

Comparing Different Attribution Methods

Captum offers numerous attribution methods, each with its strengths and weaknesses. Let's compare a few on the same input:

python
from captum.attr import (
    Saliency,
    IntegratedGradients,
    DeepLift,
    GradientShap,
    NoiseTunnel
)

# Initialize attribution methods
saliency = Saliency(model)
integrated_gradients = IntegratedGradients(model)
deep_lift = DeepLift(model)

# Compute attributions
saliency_attrs = saliency.attribute(input_tensor, target=232)
ig_attrs = integrated_gradients.attribute(input_tensor, target=232, n_steps=50)

# Add noise tunnel for integrated gradients (helps smooth attributions)
nt = NoiseTunnel(integrated_gradients)
ig_nt_attrs = nt.attribute(input_tensor, target=232, nt_samples=10, nt_type='smoothgrad')

deep_lift_attrs = deep_lift.attribute(input_tensor, target=232)

# Visualize all methods
methods = [
    ("Saliency", saliency_attrs),
    ("Integrated Gradients", ig_attrs),
    ("IG with SmoothGrad", ig_nt_attrs),
    ("DeepLift", deep_lift_attrs)
]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

axes[0].imshow(original_image)
axes[0].set_title('Original Image')

for ax, (name, attrs) in zip(axes[1:], methods):
    attrs = attrs.squeeze().cpu().detach().numpy()
    attrs = np.transpose(attrs, (1, 2, 0))

    # Draw each heat map on its own axes; the return value is (figure, axes), not an image
    viz.visualize_image_attr(attrs,
                             original_image,
                             method="heat_map",
                             sign="all",
                             title=name,
                             plt_fig_axis=(fig, ax),
                             use_pyplot=False)

axes[5].axis('off')  # the sixth panel is unused

plt.tight_layout()
plt.show()
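
The smoothing idea behind the noise tunnel can be sketched without any library: compute the attribution on many noisy copies of the input and average the results. The version below averages plain gradients of a toy function. The function, the noise level, and the sample count are assumptions for illustration, and Captum's `NoiseTunnel` wraps a full attribution method rather than a hand-written gradient.

```python
import random


def gradient(x):
    # Analytic gradient of the toy function f(x) = sum(x_i ** 3).
    return [3.0 * xi ** 2 for xi in x]


def smoothgrad(grad_fn, x, n_samples=200, stdev=0.1, seed=0):
    """Average grad_fn over n_samples copies of x perturbed with Gaussian noise."""
    rng = random.Random(seed)
    total = [0.0] * len(x)
    for _ in range(n_samples):
        noisy = [xi + rng.gauss(0.0, stdev) for xi in x]
        total = [t + g for t, g in zip(total, grad_fn(noisy))]
    return [t / n_samples for t in total]


x = [1.0, -2.0]
plain = gradient(x)
smoothed = smoothgrad(gradient, x)

print("plain gradient:   ", plain)
print("smoothed gradient:", [round(s, 3) for s in smoothed])
```

On a smooth toy function the two results are close. On a deep network, where gradients can fluctuate sharply between nearby inputs, the averaged version tends to look less noisy.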

Text Model Interpretation with Captum

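Captum works with text models too. Here's an example using BERT for sentiment classification. Note that loading 'bert-base-uncased' into BertForSequenceClassification gives a randomly initialized classification head, so the predictions and attributions below only become meaningful once you swap in a checkpoint fine-tuned for sentiment analysis: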

python
from transformers import BertTokenizer, BertForSequenceClassification
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients

# Load pre-trained model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model.eval()

# Prepare text input
text = "I really enjoyed this movie. The acting was fantastic!"
encoded_input = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
input_ids = encoded_input['input_ids']
attention_mask = encoded_input['attention_mask']

# Get prediction
output = model(input_ids, attention_mask=attention_mask)
pred_label = output.logits.argmax(dim=1).item()
pred_score = output.logits.softmax(dim=1)[0, pred_label].item()
print(f"Prediction: {'Positive' if pred_label == 1 else 'Negative'} ({pred_score:.3f})")

# Define a forward function for the attribution
def forward_func(input_ids, attention_mask=None):
    outputs = model(input_ids, attention_mask=attention_mask)
    return outputs.logits

# Helper that returns the word-embedding matrix (not used below)
def get_word_embeddings():
    return model.bert.embeddings.word_embeddings.weight.detach()

# Initialize Layer Integrated Gradients
lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)

# Get attributions with respect to the output of the embedding layer.
# The layer's input is integer token ids, which cannot carry gradients,
# so attribute_to_layer_input is left at its default (False).
# The baseline is a sequence of token id 0, which is [PAD] for bert-base-uncased.
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=torch.zeros_like(input_ids),
    additional_forward_args=(attention_mask,),
    target=pred_label,
    return_convergence_delta=True
)

# Convert indices to words
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

# Process attributions
attributions = attributions.sum(dim=-1).squeeze(0)
attributions = attributions / torch.norm(attributions)
attributions = attributions.detach().numpy()

# Visualize word-level attributions
plt.figure(figsize=(12, 4))
plt.bar(tokens, attributions)
plt.xticks(rotation=45, ha='right')
plt.title('Word Attributions')
plt.tight_layout()
plt.show()
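
The all-[PAD] baseline used above also replaces the [CLS] and [SEP] markers. A common alternative is to keep those special tokens in place and pad only the ordinary tokens. The helper below is a hypothetical sketch of that idea in plain Python; the default ids 101, 102, and 0 are the [CLS], [SEP], and [PAD] ids of bert-base-uncased, and the other ids in the example are illustrative only.

```python
def build_reference_ids(input_ids, cls_id=101, sep_id=102, pad_id=0):
    """Replace every ordinary token with [PAD], keeping [CLS] and [SEP] in place."""
    special = {cls_id, sep_id}
    return [tok if tok in special else pad_id for tok in input_ids]


# Illustrative sequence: [CLS], three ordinary token ids, [SEP].
example_ids = [101, 1045, 2428, 5632, 102]
reference_ids = build_reference_ids(example_ids)
print(reference_ids)  # [101, 0, 0, 0, 102]
```

To use such a reference with the example above, you would wrap the list in a tensor of the same shape as `input_ids` and pass it as `baselines`. With a tokenizer at hand, reading the ids from `tokenizer.cls_token_id`, `tokenizer.sep_token_id`, and `tokenizer.pad_token_id` avoids hard-coding them.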

Real-World Application: Medical Diagnosis

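One of the critical applications of model interpretability is in healthcare. Let's consider a simplified example built on scikit-learn's diabetes dataset, whose target is a quantitative measure of disease progression one year after baseline. We turn that target into a binary label (above or below average progression), train a small classifier, and use Captum to explain its predictions: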

python
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from captum.attr import IntegratedGradients, FeatureAblation

# Load diabetes dataset
from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
X = diabetes.data
y = diabetes.target

# Convert to binary classification (above/below the mean progression score)
y_binary = (y > y.mean()).astype(int)

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y_binary, test_size=0.2, random_state=42)

# Scale features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert to tensors
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.FloatTensor(y_train)
X_test_tensor = torch.FloatTensor(X_test)
y_test_tensor = torch.FloatTensor(y_test)

# Define a simple model
class DiabetesModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 8)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(8, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        return x

# Train the model
model = DiabetesModel()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

epochs = 100
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train_tensor).squeeze()
    loss = criterion(outputs, y_train_tensor)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Evaluate the model
model.eval()
with torch.no_grad():
    test_outputs = model(X_test_tensor).squeeze()
    test_preds = (test_outputs > 0.5).float()
    accuracy = (test_preds == y_test_tensor).float().mean()
    print(f"Test accuracy: {accuracy.item():.4f}")

# Select a sample to explain
sample_idx = 0
sample = X_test_tensor[sample_idx:sample_idx+1]
sample_output = model(sample).item()
sample_prediction = "High Risk" if sample_output > 0.5 else "Low Risk"
print(f"Prediction for sample: {sample_prediction} ({sample_output:.4f})")

# Use Integrated Gradients to explain prediction
ig = IntegratedGradients(model)
attributions = ig.attribute(sample, target=0, n_steps=50)

# Map attributions to feature names
feature_names = diabetes.feature_names
feature_importance = pd.DataFrame({
    'Feature': feature_names,
    'Importance': attributions.detach().numpy()[0]
})
feature_importance['abs_importance'] = feature_importance['Importance'].abs()
feature_importance = feature_importance.sort_values('abs_importance', ascending=False)

# Visualize feature importance
plt.figure(figsize=(10, 6))
colors = ['green' if x > 0 else 'red' for x in feature_importance['Importance']]
plt.barh(feature_importance['Feature'], feature_importance['Importance'], color=colors)
plt.title(f'Feature Importance for {sample_prediction} Prediction')
plt.xlabel('Attribution Score')
plt.tight_layout()
plt.show()

# Let's also try Feature Ablation for comparison
ablator = FeatureAblation(model)
ablation_attr = ablator.attribute(sample, target=0)

# Compare the methods
plt.figure(figsize=(12, 8))

plt.subplot(2, 1, 1)
plt.barh(feature_names, attributions.detach().numpy()[0])
plt.title('Integrated Gradients')

plt.subplot(2, 1, 2)
plt.barh(feature_names, ablation_attr.detach().numpy()[0])
plt.title('Feature Ablation')

plt.tight_layout()
plt.show()

This example shows how to use Captum to identify which features (like blood pressure, BMI, etc.) most influenced a diabetes risk prediction, which would be crucial for healthcare professionals to understand and trust the model's decisions.
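
Feature ablation is simple enough to write out by hand, which helps when interpreting its scores next to Integrated Gradients. The sketch below replaces one feature at a time with a baseline value and records how much the output drops. The logistic `risk_model` and its weights are hypothetical and chosen only for illustration; this is not Captum's implementation.

```python
import math


def feature_ablation(f, x, baseline):
    """Attribution of feature i = f(x) minus f(x with feature i set to its baseline)."""
    full_output = f(x)
    attrs = []
    for i in range(len(x)):
        ablated = list(x)
        ablated[i] = baseline[i]
        attrs.append(full_output - f(ablated))
    return attrs


def risk_model(x):
    # Hypothetical logistic model; the third feature has zero weight.
    w = [0.8, -0.5, 0.0]
    z = sum(wi * xi for wi, xi in zip(w, x))
    return 1.0 / (1.0 + math.exp(-z))


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]

attrs = feature_ablation(risk_model, x, baseline)
print([round(a, 4) for a in attrs])
```

A positive score means the feature raised the predicted risk relative to its baseline value, a negative score means it lowered it, and a feature the model ignores gets a score of zero.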

Summary

In this tutorial, we've explored PyTorch Captum, a powerful library for model interpretability. We've covered:

  1. Basic attribution methods like Integrated Gradients
  2. Visualizing attributions for image classification models
  3. Layer-specific attribution methods
  4. Comparing different attribution techniques
  5. Text model interpretability
  6. A real-world application in healthcare

Understanding why models make specific predictions is becoming increasingly important, especially in high-stakes domains like healthcare, finance, and criminal justice. Captum provides the tools necessary to peek inside the "black box" of deep learning models and build trust in AI systems.

Exercises

  1. Try using Captum to explain predictions of a model you've already built.
  2. Compare at least three different attribution methods on the same input and analyze the differences.
  3. Build a simple image classifier and use Captum to create an interactive web app that shows which parts of the image influence the model's decision.
  4. Apply Captum to a natural language processing task like sentiment analysis or text classification.
  5. Experiment with different baselines for Integrated Gradients and observe how they affect attributions.

