PyTorch Gradient Accumulation
When training deep learning models, you often need to use large batch sizes for better stability and convergence. However, GPU memory constraints can limit the batch size you can use. Gradient accumulation is a technique that allows you to simulate larger batch sizes by accumulating gradients across multiple smaller batches before updating the model parameters.
What is Gradient Accumulation?
Gradient accumulation is the process of:
- Computing the gradients for smaller mini-batches
- Accumulating (adding) these gradients over several iterations
- Performing a model parameter update only after accumulating gradients for the desired number of steps
This technique effectively simulates training with larger batch sizes while keeping memory requirements manageable.
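To see why the implementations below divide the loss by the number of accumulation steps, note that most PyTorch losses are mean-reduced: averaging over one large batch of size k·b gives the same gradient as averaging each of k micro-batches of size b and dividing their summed gradients by k. Here is a minimal sketch verifying this equivalence on a toy least-squares problem:

import torch

torch.manual_seed(0)
w = torch.randn(5, requires_grad=True)
x, y = torch.randn(8, 5), torch.randn(8)

# Gradient from one large batch of 8
((x @ w - y) ** 2).mean().backward()
big_batch_grad = w.grad.clone()

# Gradient accumulated over 2 micro-batches of 4, each loss divided by k=2
w.grad = None
for xb, yb in ((x[:4], y[:4]), (x[4:], y[4:])):
    (((xb @ w - yb) ** 2).mean() / 2).backward()

print(torch.allclose(w.grad, big_batch_grad))  # True (up to float error)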
Why Use Gradient Accumulation?
- Memory Efficiency: Train with larger effective batch sizes without running out of GPU memory
- Stability: Larger batch sizes can lead to more stable training
- Hardware Flexibility: Train advanced models even on GPUs with limited memory capacity
Basic Implementation
Here's a step-by-step implementation of gradient accumulation in PyTorch:
def train_with_gradient_accumulation(model, train_loader, optimizer, criterion, accumulation_steps=4):
    model.train()
    total_loss = 0
    # Set gradients to zero at the beginning
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(train_loader):
        # Forward pass
        outputs = model(inputs)
        # Calculate loss and normalize by accumulation steps
        loss = criterion(outputs, targets) / accumulation_steps
        # Backward pass (compute gradients)
        loss.backward()
        # Store the unnormalized loss for logging
        total_loss += loss.item() * accumulation_steps
        # Update parameters after accumulating gradients for the specified number of steps
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    return total_loss / len(train_loader)
Let's break down what's happening:
- We start by zeroing the gradients so accumulation begins from a clean slate
- For each batch, we perform a forward pass and calculate the loss
- We divide the loss by the number of accumulation steps to normalize the gradients
- We perform a backward pass to accumulate gradients
- We only update the model parameters and zero the gradients after processing the specified number of accumulation steps
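To try the function end to end, here's a quick, self-contained usage sketch on synthetic regression data (the model, data, and hyperparameters here are illustrative, not part of the original example):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(10, 1)
loader = DataLoader(TensorDataset(torch.randn(256, 10), torch.randn(256, 1)),
                    batch_size=16, shuffle=True)  # 16 batches, divisible by 4
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

avg_loss = train_with_gradient_accumulation(model, loader, optimizer,
                                            nn.MSELoss(), accumulation_steps=4)
print(f'Average loss: {avg_loss:.4f}')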
Complete Training Loop Example
Here's a complete example demonstrating gradient accumulation in a PyTorch training loop:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x
# Training function with gradient accumulation
def train(model, train_loader, optimizer, criterion, device, epochs=5, accumulation_steps=4):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        # Zero gradients at the beginning of each epoch
        optimizer.zero_grad()
        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets) / accumulation_steps
            # Backward pass
            loss.backward()
            # Keep track of accuracy
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            # Adjust the running loss (multiply to counteract the division earlier)
            running_loss += loss.item() * accumulation_steps
            # Update weights after accumulation_steps iterations (and on the last batch)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()
                optimizer.zero_grad()
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')
    print('Training complete!')
# Setup dataset, model, and training parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
model = SimpleModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Train with gradient accumulation (effective batch size = 32 * 4 = 128)
train(model, train_loader, optimizer, criterion, device, epochs=5, accumulation_steps=4)
This example demonstrates training an MNIST classifier with a batch size of 32, but with an effective batch size of 128 using gradient accumulation over 4 steps.
Example Output
When running this code, you might see output similar to:
Epoch 1/5, Loss: 0.3518, Accuracy: 89.11%
Epoch 2/5, Loss: 0.1087, Accuracy: 96.77%
Epoch 3/5, Loss: 0.0721, Accuracy: 97.87%
Epoch 4/5, Loss: 0.0547, Accuracy: 98.34%
Epoch 5/5, Loss: 0.0435, Accuracy: 98.57%
Training complete!
Real-world Considerations
Adjusting Learning Rate
When using gradient accumulation to simulate larger batch sizes, you might need to adjust your learning rate accordingly. As a rule of thumb, if you increase your effective batch size by a factor of k, you might want to increase your learning rate by approximately √k.
# Example: Adjusting learning rate for gradient accumulation
base_lr = 0.001 # Learning rate for batch size of 32
accumulation_steps = 4 # Simulating batch size of 128 (32 * 4)
adjusted_lr = base_lr * (accumulation_steps ** 0.5) # Scale by square root of batch size ratio
optimizer = optim.Adam(model.parameters(), lr=adjusted_lr)
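A related heuristic, the linear scaling rule (more common with SGD than with Adam), scales the learning rate by the full batch-size ratio instead. Treat both rules as starting points to tune, not guarantees:

# Linear scaling rule (a common alternative, especially for SGD)
adjusted_lr_linear = base_lr * accumulation_steps  # 0.001 * 4 = 0.004
optimizer = optim.SGD(model.parameters(), lr=adjusted_lr_linear, momentum=0.9)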
Mixed Precision Training with Gradient Accumulation
Gradient accumulation can be combined with mixed precision training for even more memory efficiency:
from torch.cuda.amp import autocast, GradScaler

def train_with_mixed_precision_and_accumulation(model, train_loader, optimizer, criterion,
                                                device, accumulation_steps=4):
    model.train()
    scaler = GradScaler()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Use autocast for mixed precision training
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets) / accumulation_steps
        # Scale loss and perform backward pass
        scaler.scale(loss).backward()
        if (i + 1) % accumulation_steps == 0:
            # scaler.step() unscales the gradients, then skips the update
            # if they contain infs/NaNs
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
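Note: in recent PyTorch releases these utilities also live under torch.amp (e.g. torch.amp.autocast('cuda') and torch.amp.GradScaler('cuda')); the torch.cuda.amp spellings shown above still work but may emit deprecation warnings on newer versions.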
Distributed Training with Gradient Accumulation
When training across multiple GPUs, gradient accumulation can still be beneficial:
def distributed_train_with_accumulation(model, train_loader, optimizer, criterion,
                                        device, accumulation_steps=4):
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps
        loss.backward()
        if (i + 1) % accumulation_steps == 0:
            # Gradient clipping is typically applied once per update step;
            # with DistributedDataParallel, gradient synchronization already
            # happens during backward(), so gradients are averaged here
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
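One concrete optimization worth knowing: with torch.nn.parallel.DistributedDataParallel (DDP), every backward() normally triggers a gradient all-reduce across processes, which is wasted communication on non-update steps. DDP's no_sync() context manager defers the all-reduce until the final micro-batch. A minimal sketch, assuming model is already wrapped in DDP and the process group has been initialized:

# Sketch: skip the gradient all-reduce on non-update steps. Assumes `model`
# is wrapped in torch.nn.parallel.DistributedDataParallel and the process
# group has been initialized (e.g. via torch.distributed.init_process_group).
def ddp_train_with_accumulation(model, train_loader, optimizer, criterion,
                                device, accumulation_steps=4):
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        if (i + 1) % accumulation_steps == 0:
            # Final micro-batch of the group: backward() outside no_sync()
            # triggers the all-reduce, averaging gradients across ranks
            loss = criterion(model(inputs), targets) / accumulation_steps
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        else:
            with model.no_sync():
                # Gradients accumulate locally; no communication happens
                loss = criterion(model(inputs), targets) / accumulation_steps
                loss.backward()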
Common Pitfalls and Solutions
Inconsistent Batch Norm Statistics
With gradient accumulation, batch normalization statistics might not be as accurate since they're computed on smaller mini-batches. Solutions include:
- Using a larger momentum value in BatchNorm layers
- Using alternatives like Layer Normalization
- Freezing BatchNorm parameters during fine-tuning with gradient accumulation
# Freeze batch norm layers when using gradient accumulation
def freeze_batch_norm_layers(model):
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()  # Use running statistics instead of batch statistics
            if module.affine:  # weight/bias exist only when affine=True
                module.weight.requires_grad = False
                module.bias.requires_grad = False
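One caveat: calling model.train() afterwards puts every submodule, including the frozen BatchNorm layers, back into training mode, so re-apply freeze_batch_norm_layers after each model.train() call (or override train() on your model to keep them in eval mode).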
Handling Uneven Batch Division
If your dataset size isn't cleanly divisible by your batch size and accumulation steps, you need to handle the last batch carefully:
# Handle the case when dataset size isn't divisible by batch size * accumulation_steps
if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
    optimizer.step()
    optimizer.zero_grad()
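Keep in mind that when the last group is short, the final update aggregates fewer than accumulation_steps micro-batches, so its effective batch size (and, because the loss is still divided by the full accumulation_steps, its gradient magnitude) is smaller than the rest; for most workloads this edge effect is negligible.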
Summary
Gradient accumulation is a powerful technique for training deep learning models with large effective batch sizes on limited hardware. It works by:
- Processing smaller mini-batches to fit in GPU memory
- Accumulating gradients across multiple iterations
- Updating model parameters only after the desired number of accumulation steps
This approach allows you to train models that would otherwise require more powerful hardware, making deep learning more accessible and efficient.
Additional Resources
- PyTorch Documentation
- Gradient Accumulation Implementation in Hugging Face Transformers
- Effect of Batch Size on Training Dynamics
Exercises
- Implement gradient accumulation for a text classification task using a pre-trained transformer model.
- Compare the training dynamics (loss curve, accuracy) between standard training and training with gradient accumulation.
- Experiment with different combinations of batch sizes and accumulation steps to find the optimal setup for your hardware.
- Implement a learning rate scheduler that adjusts the learning rate based on the effective batch size when using gradient accumulation.
- Create a function that automatically determines the optimal accumulation steps based on the model size and available GPU memory.