PyTorch Gradient Checkpointing

Introduction

When training large neural networks, especially deep ones, you might encounter the dreaded "CUDA out of memory" error. This happens because PyTorch's autograd engine stores all intermediate activations from the forward pass to calculate gradients during the backward pass.

Gradient checkpointing is a technique that trades computation for memory by:

  • Skipping the storage of some intermediate activations during the forward pass
  • Recomputing them on-demand during the backward pass

This memory-saving technique can be critical for training large models on limited GPU resources. In this tutorial, we'll explore how gradient checkpointing works in PyTorch and how you can use it in your projects.

Understanding the Memory Problem

Before diving into gradient checkpointing, let's understand why memory becomes an issue during training.

The Standard PyTorch Workflow

In a typical neural network training workflow:

  1. Forward pass: Compute the output of the model, storing all intermediate tensors
  2. Loss calculation: Compute the loss based on the model output and target
  3. Backward pass: Compute gradients of the loss with respect to parameters

The autograd engine needs to store all intermediate activations from the forward pass to correctly compute gradients during the backward pass. For deep networks, this can consume enormous amounts of memory.

python
import torch
import torch.nn as nn

# Example model with many layers
class DeepModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(1024, 1024) for _ in range(100)
        ])
        self.activation = nn.ReLU()

    def forward(self, x):
        for layer in self.layers:
            # Each intermediate result is stored in memory for the backward pass
            x = self.activation(layer(x))
        return x

# Create a large input tensor
input_tensor = torch.randn(128, 1024, device="cuda")
model = DeepModel().cuda()

# This might cause OOM (Out of Memory) for large models
output = model(input_tensor)
loss = output.sum()
loss.backward()  # All activations are stored until this point

Gradient Checkpointing to the Rescue

Gradient checkpointing works by saving only some strategically chosen activations during the forward pass. When the backward pass needs a missing activation, it's recomputed from the nearest saved checkpoint.

How It Works

  1. Divide your model into segments
  2. Save activations only at segment boundaries
  3. During backward pass, recompute activations within each segment as needed

The trade-off is simple: you perform some extra computation (a partial forward pass) to save memory.
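
To make the segment idea concrete, here is a minimal sketch (not taken from the PyTorch docs) that splits a small stack of layers into two checkpointed segments. The segment1/segment2 names are purely illustrative, and use_reentrant=False assumes a reasonably recent PyTorch release:

python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

layers = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)]).cuda()

# Treat the first and second halves of the stack as two segments.
segment1 = layers[:4]
segment2 = layers[4:]

x = torch.randn(16, 64, device="cuda", requires_grad=True)

# Only the inputs to each segment are kept; activations inside a segment
# are recomputed when the backward pass reaches it.
h = checkpoint.checkpoint(segment1, x, use_reentrant=False)
out = checkpoint.checkpoint(segment2, h, use_reentrant=False)
out.sum().backward()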

Implementing Gradient Checkpointing in PyTorch

PyTorch provides a convenient function called checkpoint in the torch.utils.checkpoint module.

Basic Usage

python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class CheckpointedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(1024, 1024) for _ in range(100)
        ])
        self.activation = nn.ReLU()

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            if i % 10 == 0:  # Apply checkpointing every 10 layers
                # use_reentrant=False selects the recommended non-reentrant variant,
                # which also handles inputs that don't require grad.
                x = checkpoint.checkpoint(self.layer_block, x, i, use_reentrant=False)
            else:
                x = self.activation(layer(x))
        return x

    def layer_block(self, x, idx):
        # This function will be recomputed during the backward pass
        return self.activation(self.layers[idx](x))

# Usage
model = CheckpointedModel().cuda()
input_tensor = torch.randn(128, 1024, device="cuda")
output = model(input_tensor)
loss = output.sum()
loss.backward()  # Some activations will be recomputed

Checkpointing Sequential Blocks

A common approach is to checkpoint entire blocks of sequential operations:

python
class EfficientModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Create 10 blocks with 10 layers each
        self.blocks = nn.ModuleList([
            self.create_block() for _ in range(10)
        ])

    def create_block(self):
        layers = []
        for _ in range(10):
            layers.append(nn.Linear(1024, 1024))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        for block in self.blocks:
            # Checkpoint at the block level
            x = checkpoint.checkpoint(block, x, use_reentrant=False)
        return x

# This model can train with much less memory
model = EfficientModel().cuda()
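
If your model is already a single nn.Sequential, PyTorch's built-in torch.utils.checkpoint.checkpoint_sequential can do this block-level splitting for you. A minimal sketch, assuming a PyTorch 2.x release where checkpoint_sequential accepts use_reentrant:

python
from torch.utils.checkpoint import checkpoint_sequential

# One long nn.Sequential instead of hand-built blocks
big_sequential = nn.Sequential(
    *[nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()) for _ in range(100)]
).cuda()

x = torch.randn(128, 1024, device="cuda", requires_grad=True)

# Split the model into 10 segments; only segment-boundary activations are stored.
out = checkpoint_sequential(big_sequential, 10, x, use_reentrant=False)
out.sum().backward()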

Real-World Example: Training a Transformer Model

Let's see how gradient checkpointing can be applied to a transformer model, which typically requires substantial memory due to the self-attention mechanism and deep architecture.

python
import math

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class CheckpointedTransformer(nn.Module):
    def __init__(self, vocab_size=10000, d_model=512, n_head=8, num_layers=12):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layers = TransformerEncoderLayer(d_model, n_head, dim_feedforward=2048)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

        # Enable gradient checkpointing on the transformer encoder
        self.use_checkpointing = True

    def forward(self, src):
        # src shape: [seq_len, batch_size]
        src = self.embedding(src) * math.sqrt(self.d_model)

        if self.use_checkpointing and self.training:
            # Apply checkpointing to the transformer encoder
            output = checkpoint.checkpoint(self.transformer_encoder, src, use_reentrant=False)
        else:
            output = self.transformer_encoder(src)

        output = self.output_layer(output)
        return output

# Create a model and some sample data
model = CheckpointedTransformer().cuda()
src_data = torch.randint(0, 10000, (100, 32)).cuda()  # [seq_len, batch_size]

# Training loop with checkpointing enabled
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

# Forward pass with reduced memory usage
outputs = model(src_data)
loss = criterion(outputs.view(-1, 10000), src_data.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()

In this example, we apply checkpointing to the entire transformer encoder module. For very deep transformers like GPT models, you might want to checkpoint individual encoder layers instead.
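
As a rough sketch of that per-layer variant (building on the imports above; the module below is hand-rolled for illustration rather than using nn.TransformerEncoder):

python
class PerLayerCheckpointedEncoder(nn.Module):
    def __init__(self, d_model=512, n_head=8, num_layers=12):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_head, dim_feedforward=2048)
            for _ in range(num_layers)
        ])

    def forward(self, src):
        # src shape: [seq_len, batch_size, d_model]
        # Checkpoint each encoder layer separately: only the input to each
        # layer is kept, and the layer's internals are recomputed on backward.
        for layer in self.layers:
            src = checkpoint.checkpoint(layer, src, use_reentrant=False)
        return src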

Advanced Techniques and Tips

Selective Checkpointing

You don't need to checkpoint every part of your model. Focus on the memory-intensive parts:

python
def forward(self, x):
    # Regular processing for lightweight layers
    x = self.initial_layers(x)

    # Checkpoint memory-intensive blocks
    for block in self.memory_intensive_blocks:
        x = checkpoint.checkpoint(block, x, use_reentrant=False)

    # Regular processing for final layers
    x = self.final_layers(x)
    return x

Nested Checkpointing

For extremely large models, you can even nest checkpoint calls. Note that nested checkpointing is only supported by the non-reentrant implementation (use_reentrant=False):

python
def forward(self, x):
    for super_block in self.super_blocks:
        # First level of checkpointing
        x = checkpoint.checkpoint(self.process_super_block, super_block, x, use_reentrant=False)
    return x

def process_super_block(self, super_block, x):
    for block in super_block:
        # Second level of checkpointing
        x = checkpoint.checkpoint(block, x, use_reentrant=False)
    return x

Memory Usage Comparison

Here's an example comparing memory usage with and without checkpointing:

python
def check_memory_usage():
    model_standard = DeepModel().cuda()
    model_checkpointed = CheckpointedModel().cuda()

    # Clear cache
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Test standard model
    input_tensor = torch.randn(128, 1024, device="cuda")
    output = model_standard(input_tensor)
    loss = output.sum()
    loss.backward()
    standard_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB

    # Free the first run's tensors and gradients, then clear cache
    del output, loss
    model_standard.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Test checkpointed model
    input_tensor = torch.randn(128, 1024, device="cuda")
    output = model_checkpointed(input_tensor)
    loss = output.sum()
    loss.backward()
    checkpointed_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB

    print(f"Standard model memory usage: {standard_memory:.2f} GB")
    print(f"Checkpointed model memory usage: {checkpointed_memory:.2f} GB")
    print(f"Memory saved: {standard_memory - checkpointed_memory:.2f} GB")

Using Checkpointing with Existing Models

You can also apply checkpointing to existing models by wrapping their forward methods:

python
import torch
import torch.nn as nn
import torchvision.models as models
import torch.utils.checkpoint as checkpoint

# Load a pre-trained ResNet (pretrained=True is deprecated in newer torchvision)
original_model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT).cuda()

# Create a wrapper with checkpointing
class CheckpointedResNet(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        self.model = original_model

        # Store references to key modules
        self.conv1 = self.model.conv1
        self.bn1 = self.model.bn1
        self.relu = self.model.relu
        self.maxpool = self.model.maxpool

        self.layer1 = self.model.layer1
        self.layer2 = self.model.layer2
        self.layer3 = self.model.layer3
        self.layer4 = self.model.layer4

        self.avgpool = self.model.avgpool
        self.fc = self.model.fc

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Apply checkpointing to the most memory-intensive layers.
        # Note: checkpointed segments are recomputed during backward, so in
        # training mode BatchNorm running statistics are updated twice per step.
        x = checkpoint.checkpoint(self.layer1, x, use_reentrant=False)
        x = checkpoint.checkpoint(self.layer2, x, use_reentrant=False)
        x = checkpoint.checkpoint(self.layer3, x, use_reentrant=False)
        x = checkpoint.checkpoint(self.layer4, x, use_reentrant=False)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# Create the checkpointed version
efficient_model = CheckpointedResNet(original_model)
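
A quick smoke test of the wrapped model (the input shape is just illustrative):

python
images = torch.randn(8, 3, 224, 224, device="cuda")
logits = efficient_model(images)   # shape: [8, 1000]
logits.sum().backward()            # gradients flow through the checkpointed layers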

Common Pitfalls and Solutions

RNG States

When a checkpointed function uses random number generators (for example, dropout), the recomputed forward pass needs to see the same random numbers as the original one. By default, PyTorch's checkpoint function saves and restores the RNG state for you (preserve_rng_state=True), so the recomputation matches. If you disable this with preserve_rng_state=False to reduce overhead, the recomputed forward pass may use different random numbers than the original.
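
A small illustration of the flag (preserve_rng_state and use_reentrant are real keyword arguments of torch.utils.checkpoint.checkpoint; the block itself is a toy):

python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

block = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.5)).cuda()
x = torch.randn(8, 64, device="cuda", requires_grad=True)

# Default: the RNG state is saved and restored, so the dropout mask used
# during recomputation matches the one from the original forward pass.
out = checkpoint.checkpoint(block, x, use_reentrant=False, preserve_rng_state=True)

# Opting out skips the RNG bookkeeping (slightly faster), but the recomputed
# dropout mask may differ from the original one.
out_fast = checkpoint.checkpoint(block, x, use_reentrant=False, preserve_rng_state=False)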

Non-Differentiable Functions

If your checkpointed function contains non-differentiable operations, autograd might fail. Ensure all operations in the checkpointed function are differentiable.

Data Dependence

The functions you checkpoint should only depend on the input tensors and model parameters, not on global variables or other changing state.
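
A sketch of the kind of hidden dependency to avoid; bad_block and good_block are made-up names for illustration, and the imports from the earlier snippets are assumed:

python
scale = torch.randn(1024, device="cuda")

# Problematic: the function reads `scale` from the enclosing scope, so it is
# not tracked as a checkpoint input and may have changed by recompute time.
def bad_block(x):
    return torch.relu(x * scale)

# Better: pass everything the function needs as explicit arguments.
def good_block(x, scale):
    return torch.relu(x * scale)

x = torch.randn(128, 1024, device="cuda", requires_grad=True)
out = checkpoint.checkpoint(good_block, x, scale, use_reentrant=False)
out.sum().backward()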

Performance Considerations

While checkpointing saves memory, it increases computation time. Here's a simple benchmark:

python
import time

def benchmark_checkpointing():
    model_standard = DeepModel().cuda()
    model_checkpointed = CheckpointedModel().cuda()
    input_tensor = torch.randn(128, 1024, device="cuda")

    # Benchmark standard model
    torch.cuda.synchronize()  # CUDA ops are asynchronous, so sync before timing
    start = time.time()
    for _ in range(10):
        output = model_standard(input_tensor)
        loss = output.sum()
        loss.backward()
    torch.cuda.synchronize()
    standard_time = (time.time() - start) / 10

    # Benchmark checkpointed model
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(10):
        output = model_checkpointed(input_tensor)
        loss = output.sum()
        loss.backward()
    torch.cuda.synchronize()
    checkpointed_time = (time.time() - start) / 10

    print(f"Standard model average time: {standard_time:.4f}s")
    print(f"Checkpointed model average time: {checkpointed_time:.4f}s")
    print(f"Slowdown factor: {checkpointed_time / standard_time:.2f}x")

Typically, you might see a slowdown factor of 1.2x to 1.5x, which is often a worthwhile trade-off for the memory savings.

Summary

Gradient checkpointing is a powerful technique for training large neural networks with limited memory resources. By strategically recomputing activations during the backward pass instead of storing them all during the forward pass, you can significantly reduce memory usage at the cost of additional computation.

Key takeaways:

  1. Gradient checkpointing trades computation for memory by recomputing activations
  2. It's easy to implement in PyTorch using torch.utils.checkpoint.checkpoint
  3. Apply checkpointing selectively to memory-intensive parts of your model
  4. Expect a moderate increase in training time (20-50% typically)
  5. It's crucial for training very deep networks or when using limited GPU resources

Exercises

  1. Implement gradient checkpointing on a simple convolutional neural network and measure the memory savings.
  2. Modify the checkpointed transformer example to apply checkpointing at the individual encoder layer level instead of the entire encoder.
  3. Experiment with different checkpointing frequencies (e.g., every 2 layers, every 5 layers) and analyze the trade-off between memory usage and computation time.
  4. Implement checkpointing for a U-Net model for image segmentation, focusing on the decoder part where most memory is used.
  5. Create a custom checkpointing strategy that dynamically decides which layers to checkpoint based on their memory usage.

