PyTorch Gradient Accumulation

When training deep learning models, you often need to use large batch sizes for better stability and convergence. However, GPU memory constraints can limit the batch size you can use. Gradient accumulation is a technique that allows you to simulate larger batch sizes by accumulating gradients across multiple smaller batches before updating the model parameters.

What is Gradient Accumulation?

Gradient accumulation is the process of:

  1. Computing the gradients for smaller mini-batches
  2. Accumulating (adding) these gradients over several iterations
  3. Performing a model parameter update only after accumulating gradients for the desired number of steps

This technique effectively simulates training with larger batch sizes while keeping memory requirements manageable.
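
To see why this works, note that for a mean-reduced loss over equally sized micro-batches, summing the gradients of the k normalized micro-batch losses reproduces the gradient of one large batch exactly. The following toy sketch (the model and data are purely illustrative) verifies this:

python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
criterion = nn.MSELoss()  # mean-reduced loss
x, y = torch.randn(32, 10), torch.randn(32, 1)

# Gradient from a single large batch of 32
model.zero_grad()
criterion(model(x), y).backward()
large_batch_grad = model.weight.grad.clone()

# Gradient accumulated over 4 micro-batches of 8, each loss divided by 4
model.zero_grad()
for xb, yb in zip(x.chunk(4), y.chunk(4)):
    (criterion(model(xb), yb) / 4).backward()

# Should print True (up to floating-point error)
print(torch.allclose(large_batch_grad, model.weight.grad, atol=1e-6))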

Why Use Gradient Accumulation?

  • Memory Efficiency: Train with larger effective batch sizes without running out of GPU memory
  • Stability: Larger batch sizes can lead to more stable training
  • Hardware Flexibility: Train advanced models even on GPUs with limited memory capacity

Basic Implementation

Here's a step-by-step implementation of gradient accumulation in PyTorch:

python
def train_with_gradient_accumulation(model, train_loader, optimizer, criterion, accumulation_steps=4):
    model.train()
    total_loss = 0

    # Set gradients to zero at the beginning
    optimizer.zero_grad()

    for i, (inputs, targets) in enumerate(train_loader):
        # Forward pass
        outputs = model(inputs)

        # Calculate loss and normalize by accumulation steps
        loss = criterion(outputs, targets) / accumulation_steps

        # Backward pass (compute gradients)
        loss.backward()

        # Store the un-normalized loss for logging
        total_loss += loss.item() * accumulation_steps

        # Update parameters after accumulating gradients for the specified number of steps
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    return total_loss / len(train_loader)

Let's break down what's happening:

  1. We zero any existing gradients before the loop starts
  2. For each batch, we perform a forward pass and calculate the loss
  3. We divide the loss by the number of accumulation steps to normalize the gradients
  4. We perform a backward pass to accumulate gradients
  5. We only update the model parameters and zero the gradients after processing the specified number of accumulation steps
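
A minimal call might look like this (model, train_loader, optimizer, and criterion are assumed to be defined, as in the complete example below):

python
# Hypothetical usage of the function above
avg_loss = train_with_gradient_accumulation(
    model, train_loader, optimizer, criterion, accumulation_steps=4
)
print(f'Average loss: {avg_loss:.4f}')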

Complete Training Loop Example

Here's a complete example demonstrating gradient accumulation in a PyTorch training loop:

python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

# Training function with gradient accumulation
def train(model, train_loader, optimizer, criterion, device, epochs=5, accumulation_steps=4):
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        # Zero gradients at the beginning of each epoch
        optimizer.zero_grad()

        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets) / accumulation_steps

            # Backward pass
            loss.backward()

            # Keep track of accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

            # Adjust the running loss (multiply to counteract the division earlier)
            running_loss += loss.item() * accumulation_steps

            # Update weights after accumulation_steps iterations (or at the end of the epoch)
            if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                optimizer.step()
                optimizer.zero_grad()

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

    print('Training complete!')

# Setup dataset, model, and training parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

model = SimpleModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train with gradient accumulation (effective batch size = 32 * 4 = 128)
train(model, train_loader, optimizer, criterion, device, epochs=5, accumulation_steps=4)

This example demonstrates training an MNIST classifier with a batch size of 32, but with an effective batch size of 128 using gradient accumulation over 4 steps.

Example Output

When running this code, you might see output similar to:

Epoch 1/5, Loss: 0.3518, Accuracy: 89.11%
Epoch 2/5, Loss: 0.1087, Accuracy: 96.77%
Epoch 3/5, Loss: 0.0721, Accuracy: 97.87%
Epoch 4/5, Loss: 0.0547, Accuracy: 98.34%
Epoch 5/5, Loss: 0.0435, Accuracy: 98.57%
Training complete!

Real-world Considerations

Adjusting Learning Rate

When using gradient accumulation to simulate larger batch sizes, you might need to adjust your learning rate accordingly. As a rule of thumb, if you increase your effective batch size by a factor of k, you might want to increase your learning rate by approximately √k.

python
# Example: Adjusting learning rate for gradient accumulation
base_lr = 0.001 # Learning rate for batch size of 32
accumulation_steps = 4 # Simulating batch size of 128 (32 * 4)
adjusted_lr = base_lr * (accumulation_steps ** 0.5) # Scale by square root of batch size ratio

optimizer = optim.Adam(model.parameters(), lr=adjusted_lr)
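
The linear scaling rule, multiplying the base learning rate by k (usually combined with a warmup period), is another widely used heuristic, particularly for SGD. Which rule works better depends on the optimizer and the task, so treat both as starting points for tuning rather than hard rules.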

Mixed Precision Training with Gradient Accumulation

Gradient accumulation can be combined with mixed precision training for even more memory efficiency:

python
from torch.cuda.amp import autocast, GradScaler

def train_with_mixed_precision_and_accumulation(model, train_loader, optimizer, criterion,
                                                device, accumulation_steps=4):
    model.train()
    scaler = GradScaler()

    # Set gradients to zero at the beginning
    optimizer.zero_grad()

    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Use autocast for mixed precision training
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets) / accumulation_steps

        # Scale the loss and perform the backward pass
        scaler.scale(loss).backward()

        if (i + 1) % accumulation_steps == 0:
            # scaler.step() unscales the gradients, skips the update if they
            # contain infs/NaNs, and otherwise calls optimizer.step()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
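
If you also want to clip gradients in this setup, the scaled gradients must be unscaled first so that clipping sees their true magnitudes. A minimal sketch of the update branch above, using the scaler's unscale_ method:

python
# Sketch: gradient clipping with AMP; unscale_ must run before clipping
if (i + 1) % accumulation_steps == 0:
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # will not unscale a second time
    scaler.update()
    optimizer.zero_grad()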

Distributed Training with Gradient Accumulation

When training across multiple GPUs, gradient accumulation can still be beneficial:

python
def distributed_train_with_accumulation(model, train_loader, optimizer, criterion,
                                        device, accumulation_steps=4):
    model.train()
    optimizer.zero_grad()

    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps

        loss.backward()

        if (i + 1) % accumulation_steps == 0:
            # Optionally clip gradients before the update; with DDP, gradient
            # synchronization itself happens automatically during backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
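
If the model is wrapped in DistributedDataParallel (DDP), every backward() triggers an all-reduce of gradients across processes, which is wasted communication on non-update steps. DDP's no_sync() context manager skips it; here is a sketch of the inner loop body, assuming model is a DDP instance (note that no_sync() must wrap the forward pass as well as the backward pass):

python
# Sketch: skip gradient all-reduce on non-update steps with DDP
is_update_step = (i + 1) % accumulation_steps == 0
if is_update_step:
    outputs = model(inputs)
    loss = criterion(outputs, targets) / accumulation_steps
    loss.backward()  # this backward triggers gradient synchronization
else:
    with model.no_sync():
        outputs = model(inputs)
        loss = criterion(outputs, targets) / accumulation_steps
        loss.backward()  # gradients accumulate locally, no communication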

Common Pitfalls and Solutions

Inconsistent Batch Norm Statistics

With gradient accumulation, batch normalization statistics might not be as accurate since they're computed on smaller mini-batches. Solutions include:

  1. Adjusting the momentum of BatchNorm layers (in PyTorch, momentum is the weight given to each new batch's statistics, so a lower value smooths the noisy small-batch estimates into the running stats); see the sketch after the freezing example below
  2. Using alternatives like Layer Normalization
  3. Freezing BatchNorm parameters during fine-tuning with gradient accumulation

python
# Freeze batch norm layers when using gradient accumulation; call this after
# model.train(), since model.train() puts BatchNorm back in training mode
def freeze_batch_norm_layers(model):
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()  # Use running statistics instead of batch statistics
            module.weight.requires_grad = False
            module.bias.requires_grad = False
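
For the first option, a minimal sketch that lowers the momentum of every BatchNorm layer (the value 0.05 is illustrative; PyTorch's default is 0.1):

python
# Sketch: smooth BatchNorm running statistics across noisy small batches
# by lowering momentum (the weight given to each new batch's statistics)
def lower_batch_norm_momentum(model, momentum=0.05):
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.momentum = momentum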

Handling Uneven Batch Division

If your dataset size isn't cleanly divisible by your batch size and accumulation steps, you need to handle the last batch carefully:

python
# Handle the case when the dataset size isn't divisible by batch_size * accumulation_steps
if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
    optimizer.step()
    optimizer.zero_grad()
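
A complementary option is setting drop_last=True on the DataLoader, which discards a ragged final mini-batch. Note that the number of batches may still not divide evenly by accumulation_steps, so the final-iteration check above remains useful:

python
# drop_last=True discards an incomplete final mini-batch, but len(train_loader)
# may still not be a multiple of accumulation_steps, so keep the check above
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)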

Summary

Gradient accumulation is a powerful technique for training deep learning models with large effective batch sizes on limited hardware. It works by:

  1. Processing smaller mini-batches to fit in GPU memory
  2. Accumulating gradients across multiple iterations
  3. Updating model parameters only after the desired number of accumulation steps

This approach allows you to train models that would otherwise require more powerful hardware, making deep learning more accessible and efficient.

Exercises

  1. Implement gradient accumulation for a text classification task using a pre-trained transformer model.
  2. Compare the training dynamics (loss curve, accuracy) between standard training and training with gradient accumulation.
  3. Experiment with different combinations of batch sizes and accumulation steps to find the optimal setup for your hardware.
  4. Implement a learning rate scheduler that adjusts the learning rate based on the effective batch size when using gradient accumulation.
  5. Create a function that automatically determines the optimal accumulation steps based on the model size and available GPU memory.

