PyTorch Gradient Checkpointing
Introduction
When training large neural networks, especially deep ones, you might encounter the dreaded "CUDA out of memory" error. This happens because PyTorch's autograd engine stores all intermediate activations from the forward pass to calculate gradients during the backward pass.
Gradient checkpointing is a technique that trades computation for memory by:
- Skipping the storage of some intermediate activations during the forward pass
- Recomputing them on-demand during the backward pass
This memory-saving technique can be critical for training large models on limited GPU resources. In this tutorial, we'll explore how gradient checkpointing works in PyTorch and how you can use it in your projects.
Understanding the Memory Problem
Before diving into gradient checkpointing, let's understand why memory becomes an issue during training.
The Standard PyTorch Workflow
In a typical neural network training workflow:
- Forward pass: Compute the output of the model, storing all intermediate tensors
- Loss calculation: Compute the loss based on the model output and target
- Backward pass: Compute gradients of the loss with respect to parameters
The autograd engine needs to store all intermediate activations from the forward pass to correctly compute gradients during the backward pass. For deep networks, this can consume enormous amounts of memory.
import torch
import torch.nn as nn

# Example model with many layers
class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(1024, 1024) for _ in range(100)
        ])
        self.activation = nn.ReLU()

    def forward(self, x):
        for layer in self.layers:
            # Each intermediate result is stored in memory for the backward pass
            x = self.activation(layer(x))
        return x

# Create a large input tensor
input_tensor = torch.randn(128, 1024, device="cuda")
model = DeepModel().cuda()

# This might cause OOM (Out of Memory) for large models
output = model(input_tensor)
loss = output.sum()
loss.backward()  # All activations are stored until this point
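If you want to see the problem in numbers, you can repeat the same training step with a couple of memory probes. The exact figures will vary with your GPU, batch size, and PyTorch version, so treat this as a rough gauge rather than a benchmark:

# Reset the peak-memory counter, then run one forward/backward step
torch.cuda.reset_peak_memory_stats()
output = model(input_tensor)
print(f"After forward: {torch.cuda.memory_allocated() / 1024**3:.2f} GB allocated")
loss = output.sum()
loss.backward()
print(f"Peak during the step: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")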
Gradient Checkpointing to the Rescue
Gradient checkpointing works by saving only some strategically chosen activations during the forward pass. When the backward pass needs a missing activation, it's recomputed from the nearest saved checkpoint.
How It Works
- Divide your model into segments
- Save activations only at segment boundaries
- During backward pass, recompute activations within each segment as needed
The trade-off is simple: you perform some extra computation (a partial forward pass) to save memory.
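To put rough numbers on that trade-off, take the 100-layer model above and split it into 10 segments of 10 layers each. The forward pass then keeps only the ~10 segment-boundary activations, and the backward pass recomputes at most one 10-layer segment at a time, so peak activation storage is on the order of 10 + 10 = 20 layer outputs instead of 100, at the cost of roughly one extra forward pass through each segment. In general, splitting an n-layer chain into about √n segments brings activation memory from O(n) down to O(√n) layers, ignoring constants.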
Implementing Gradient Checkpointing in PyTorch
PyTorch provides a convenient checkpoint function in the torch.utils.checkpoint module.
Basic Usage
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(1024, 1024) for _ in range(100)
        ])
        self.activation = nn.ReLU()

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if i % 10 == 0:  # Checkpoint every 10th layer
                x = checkpoint.checkpoint(self.layer_block, x, i)
            else:
                x = self.activation(layer(x))
        return x

    def layer_block(self, x, idx):
        # This function will be recomputed during the backward pass
        return self.activation(self.layers[idx](x))

# Usage
model = CheckpointedModel().cuda()
input_tensor = torch.randn(128, 1024, device="cuda")
output = model(input_tensor)
loss = output.sum()
loss.backward()  # Some activations will be recomputed
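A caveat worth knowing about: depending on your PyTorch version, these calls may emit warnings. Recent releases ask you to pass use_reentrant explicitly, and the older reentrant implementation also warns when none of the checkpointed inputs require gradients (as happens for the very first layer above, whose input is the raw data), in which case that segment's parameters may not receive gradients at all. A minimal sketch of the call with the newer non-reentrant implementation selected, which avoids both issues on recent PyTorch releases:

# Inside forward(): explicitly pick the non-reentrant implementation
x = checkpoint.checkpoint(self.layer_block, x, i, use_reentrant=False)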
Checkpointing Sequential Blocks
A common approach is to checkpoint entire blocks of sequential operations:
class EfficientModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Create 10 blocks with 10 layers each
        self.blocks = nn.ModuleList([
            self.create_block() for _ in range(10)
        ])

    def create_block(self):
        layers = []
        for _ in range(10):
            layers.append(nn.Linear(1024, 1024))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        for block in self.blocks:
            # Checkpoint at the block level
            x = checkpoint.checkpoint(block, x)
        return x

# This model can train with much less memory
model = EfficientModel().cuda()
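When the whole stack is an nn.Sequential, PyTorch also ships a helper, torch.utils.checkpoint.checkpoint_sequential, which splits the sequence into a given number of segments and checkpoints each segment for you. Here is a minimal sketch of the same 100-layer stack using it (layer sizes mirror the examples above; the use_reentrant flag is accepted on recent PyTorch releases):

from torch.utils.checkpoint import checkpoint_sequential

class SequentialCheckpointedModel(nn.Module):
    def __init__(self, num_layers=100, num_segments=10):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(1024, 1024))
            layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)
        self.num_segments = num_segments

    def forward(self, x):
        # Splits self.net into num_segments chunks and checkpoints them,
        # so roughly only the chunk-boundary activations are kept
        return checkpoint_sequential(self.net, self.num_segments, x, use_reentrant=False)

model = SequentialCheckpointedModel().cuda()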
Real-World Example: Training a Transformer Model
Let's see how gradient checkpointing can be applied to a transformer model, which typically requires substantial memory due to the self-attention mechanism and deep architecture.
import math

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class CheckpointedTransformer(nn.Module):
    def __init__(self, vocab_size=10000, d_model=512, n_head=8, num_layers=12):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layers = TransformerEncoderLayer(d_model, n_head, dim_feedforward=2048)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.d_model = d_model
        # Enable gradient checkpointing on the transformer encoder
        self.use_checkpointing = True

    def forward(self, src):
        # src shape: [seq_len, batch_size]
        src = self.embedding(src) * math.sqrt(self.d_model)
        if self.use_checkpointing and self.training:
            # Apply checkpointing to the transformer encoder
            output = checkpoint.checkpoint(self.transformer_encoder, src)
        else:
            output = self.transformer_encoder(src)
        output = self.output_layer(output)
        return output

# Create a model and some sample data
model = CheckpointedTransformer().cuda()
src_data = torch.randint(0, 10000, (100, 32)).cuda()  # [seq_len, batch_size]

# Training loop with checkpointing enabled
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

# Forward pass with reduced memory usage
outputs = model(src_data)
loss = criterion(outputs.view(-1, 10000), src_data.view(-1))  # dummy target: reconstruct the input tokens
optimizer.zero_grad()
loss.backward()
optimizer.step()
In this example, we apply checkpointing to the entire transformer encoder module. For very deep transformers like GPT models, you might want to checkpoint individual encoder layers instead.
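As a rough sketch of that per-layer strategy, you can hold the encoder layers in an nn.ModuleList and checkpoint each one individually, which keeps the recomputation working set during the backward pass down to a single layer. The class below is illustrative (its name and defaults are assumptions, not part of the model above):

class LayerCheckpointedEncoder(nn.Module):
    def __init__(self, d_model=512, n_head=8, num_layers=12):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_head, dim_feedforward=2048)
            for _ in range(num_layers)
        ])

    def forward(self, src):
        for layer in self.layers:
            # Each encoder layer is recomputed independently during backward
            src = checkpoint.checkpoint(layer, src, use_reentrant=False)
        return src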
Advanced Techniques and Tips
Selective Checkpointing
You don't need to checkpoint every part of your model. Focus on the memory-intensive parts:
def forward(self, x):
    # Regular processing for lightweight layers
    x = self.initial_layers(x)

    # Checkpoint memory-intensive blocks
    for block in self.memory_intensive_blocks:
        x = checkpoint.checkpoint(block, x)

    # Regular processing for final layers
    x = self.final_layers(x)
    return x
Nested Checkpointing
For extremely large models, you can use nested checkpointing:
def forward(self, x):
    for super_block in self.super_blocks:
        # First level of checkpointing
        x = checkpoint.checkpoint(self.process_super_block, super_block, x)
    return x

def process_super_block(self, super_block, x):
    for block in super_block:
        # Second level of checkpointing
        x = checkpoint.checkpoint(block, x)
    return x
Memory Usage Comparison
Here's an example comparing memory usage with and without checkpointing:
def check_memory_usage():
    model_standard = DeepModel().cuda()
    model_checkpointed = CheckpointedModel().cuda()

    # Clear cache and reset the peak counter
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Test standard model
    input_tensor = torch.randn(128, 1024, device="cuda")
    output = model_standard(input_tensor)
    loss = output.sum()
    loss.backward()
    standard_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB

    # Free the first run's gradients, clear cache, and reset the peak counter
    model_standard.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Test checkpointed model
    input_tensor = torch.randn(128, 1024, device="cuda")
    output = model_checkpointed(input_tensor)
    loss = output.sum()
    loss.backward()
    checkpointed_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB

    print(f"Standard model memory usage: {standard_memory:.2f} GB")
    print(f"Checkpointed model memory usage: {checkpointed_memory:.2f} GB")
    print(f"Memory saved: {standard_memory - checkpointed_memory:.2f} GB")
Using Checkpointing with Existing Models
You can also apply checkpointing to existing models by wrapping their forward methods:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.utils.checkpoint as checkpoint

# Load a pre-trained ResNet
original_model = models.resnet101(pretrained=True).cuda()

# Create a wrapper with checkpointing
class CheckpointedResNet(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        self.model = original_model
        # Store references to key modules
        self.conv1 = self.model.conv1
        self.bn1 = self.model.bn1
        self.relu = self.model.relu
        self.maxpool = self.model.maxpool
        self.layer1 = self.model.layer1
        self.layer2 = self.model.layer2
        self.layer3 = self.model.layer3
        self.layer4 = self.model.layer4
        self.avgpool = self.model.avgpool
        self.fc = self.model.fc

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Apply checkpointing to the most memory-intensive layers
        x = checkpoint.checkpoint(self.layer1, x)
        x = checkpoint.checkpoint(self.layer2, x)
        x = checkpoint.checkpoint(self.layer3, x)
        x = checkpoint.checkpoint(self.layer4, x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# Create the checkpointed version
efficient_model = CheckpointedResNet(original_model)
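A quick sanity check is to run one training step on a dummy batch; the shapes below assume standard 224x224 RGB inputs. One thing to watch in your own runs: layers with running statistics, such as BatchNorm, execute their forward pass twice per step under checkpointing, which can slightly perturb the running estimates.

# Dummy batch for a single forward/backward step (shapes are illustrative)
images = torch.randn(8, 3, 224, 224, device="cuda")
efficient_model.train()  # checkpointing only pays off when gradients are needed
logits = efficient_model(images)  # [8, 1000] class scores
loss = logits.sum()
loss.backward()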
Common Pitfalls and Solutions
RNG States
When using checkpointing with operations that rely on random number generators (like dropout), you need to be careful about the RNG state. By default, PyTorch's checkpoint function saves and restores the CPU and CUDA RNG state (preserve_rng_state=True), so the recomputed forward pass sees the same random numbers as the original. If you disable this to shave off a little overhead, the recomputation may use different random values.
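A minimal sketch of both modes, where dropout_block and deterministic_block are hypothetical methods standing in for parts of your model:

# Default: RNG state is stashed and restored, so dropout masks match on recompute
x = checkpoint.checkpoint(self.dropout_block, x, use_reentrant=False)

# If the block has no randomness, skipping the RNG bookkeeping saves a little overhead
x = checkpoint.checkpoint(self.deterministic_block, x,
                          use_reentrant=False, preserve_rng_state=False)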
Non-Differentiable Functions
If your checkpointed function contains non-differentiable operations, autograd might fail. Ensure all operations in the checkpointed function are differentiable.
Data Dependence
The functions you checkpoint should only depend on the input tensors and model parameters, not on global variables or other changing state.
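As a sketch of the kind of dependence to avoid, pass changing values into the checkpointed function as explicit arguments rather than reading them from surrounding state (the names below are made up for illustration):

# Risky: the recomputed forward may see a different value of some_global_scale
def block_bad(self, x):
    return self.linear(x) * some_global_scale

# Safer: everything the function needs is an explicit argument, replayed on recompute
def block_good(self, x, scale):
    return self.linear(x) * scale

x = checkpoint.checkpoint(self.block_good, x, scale, use_reentrant=False)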
Performance Considerations
While checkpointing saves memory, it increases computation time. Here's a simple benchmark:
import time

def benchmark_checkpointing():
    model_standard = DeepModel().cuda()
    model_checkpointed = CheckpointedModel().cuda()
    input_tensor = torch.randn(128, 1024, device="cuda")

    # Benchmark standard model
    torch.cuda.synchronize()  # CUDA ops are asynchronous, so sync before and after timing
    start = time.time()
    for _ in range(10):
        output = model_standard(input_tensor)
        loss = output.sum()
        loss.backward()
    torch.cuda.synchronize()
    standard_time = (time.time() - start) / 10

    # Benchmark checkpointed model
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(10):
        output = model_checkpointed(input_tensor)
        loss = output.sum()
        loss.backward()
    torch.cuda.synchronize()
    checkpointed_time = (time.time() - start) / 10

    print(f"Standard model average time: {standard_time:.4f}s")
    print(f"Checkpointed model average time: {checkpointed_time:.4f}s")
    print(f"Slowdown factor: {checkpointed_time/standard_time:.2f}x")
Typically, you might see a slowdown factor of 1.2x to 1.5x, which is often a worthwhile trade-off for the memory savings.
Summary
Gradient checkpointing is a powerful technique for training large neural networks with limited memory resources. By strategically recomputing activations during the backward pass instead of storing them all during the forward pass, you can significantly reduce memory usage at the cost of additional computation.
Key takeaways:
- Gradient checkpointing trades computation for memory by recomputing activations
- It's easy to implement in PyTorch using torch.utils.checkpoint.checkpoint
- Apply checkpointing selectively to memory-intensive parts of your model
- Expect a moderate increase in training time (20-50% typically)
- It's crucial for training very deep networks or when using limited GPU resources
Additional Resources
- PyTorch Documentation on Checkpointing
- Memory-Efficient Implementation of DenseNets paper - an influential application of activation recomputation
- Customizing gradient checkpointing in Hugging Face Transformers
Exercises
- Implement gradient checkpointing on a simple convolutional neural network and measure the memory savings.
- Modify the checkpointed transformer example to apply checkpointing at the individual encoder layer level instead of the entire encoder.
- Experiment with different checkpointing frequencies (e.g., every 2 layers, every 5 layers) and analyze the trade-off between memory usage and computation time.
- Implement checkpointing for a U-Net model for image segmentation, focusing on the decoder part where most memory is used.
- Create a custom checkpointing strategy that dynamically decides which layers to checkpoint based on their memory usage.
If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)