
PyTorch Backward Pass

Introduction

The backward pass is one of the most powerful yet often mysterious parts of training neural networks in PyTorch. It is the step where gradients are computed, and those gradients are what allow your model to learn from data. In this tutorial, we'll demystify the backward pass, explain how PyTorch's autograd works, and show you how to use it effectively in your training loops.

When training neural networks, we use two main phases:

  1. Forward Pass: Computing the output and the loss
  2. Backward Pass: Computing the gradients of the loss with respect to the weights, which an optimizer then uses to update the weights

If you've ever called .backward() on a PyTorch tensor, you've already used the backward pass. Let's dive deeper into how it works.

Understanding Autograd

PyTorch's automatic differentiation engine (autograd) is the foundation of the backward pass. It records operations as they happen and then replays them in reverse to calculate gradients.

Key Components

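  • Computational Graph: PyTorch builds a directed acyclic graph (DAG) of operations during the forward pass.
  • Gradient Functions: Each recorded operation has an associated gradient function (exposed as a tensor's grad_fn) that is used during the backward pass.
  • Leaf Tensors: Tensors created directly by the user rather than produced by a tracked operation, typically model parameters. When a leaf tensor has requires_grad=True, the backward pass stores its gradient in its .grad attribute.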
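A quick way to see these components is to inspect `is_leaf` and `grad_fn` directly. The sketch below is illustrative. It uses `retain_grad()`, which asks autograd to also keep the gradient of a non-leaf tensor. By default, only leaf tensors get `.grad` populated.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)  # leaf: created directly by the user
y = x * 3                                  # non-leaf: produced by a tracked operation
y.retain_grad()                            # also keep y.grad after backward()
z = y ** 2

print(x.is_leaf, y.is_leaf)  # True False
print(x.grad_fn)             # None: leaves have no gradient function
print(y.grad_fn)             # a MulBackward0 node

z.backward()
print(x.grad)  # dz/dx = 2 * y * 3 = 2 * 6 * 3 = 36
print(y.grad)  # dz/dy = 2 * y = 12
```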

How Autograd Tracks Operations

Let's start with a simple example:

```python
import torch

# Create tensors with gradient tracking
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

# Perform operations
z = x**2 + y**3

# Check the computational graph
print(f"z: {z}")
print(f"z.grad_fn: {z.grad_fn}")
```

Output:

```
z: tensor(31., grad_fn=<AddBackward0>)
z.grad_fn: <AddBackward0 object at 0x7f8b1c3f9b50>
```

The grad_fn attribute shows that PyTorch is tracking the computation history for this tensor.
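The recorded graph can be reached from `grad_fn` itself. Each node exposes an introspection attribute, `next_functions`, which points to the nodes that produced its inputs. The helper below walks that structure for the same `z = x**2 + y**3`. Note that `graph_nodes` is a hypothetical name chosen for this illustration, not a PyTorch API. The `AccumulateGrad` nodes at the bottom are where gradients get written into the leaves' `.grad`.

```python
import torch

def graph_nodes(fn, depth=0, out=None):
    """Hypothetical helper: collect (depth, node name) pairs by following next_functions."""
    if out is None:
        out = []
    if fn is None:  # inputs that do not require grad appear as None
        return out
    out.append((depth, type(fn).__name__))
    for next_fn, _ in fn.next_functions:
        graph_nodes(next_fn, depth + 1, out)
    return out

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
z = x**2 + y**3

# Print the graph as an indented tree, starting from the output's grad_fn
for depth, name in graph_nodes(z.grad_fn):
    print("  " * depth + name)
```

The printed tree should start at `AddBackward0`, branch into one `PowBackward0` node per power, and end in one `AccumulateGrad` node per leaf.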

The Backward Pass Explained

When we call .backward() on a tensor, PyTorch calculates the gradients of that tensor with respect to all leaf tensors that require gradients, and accumulates them into each leaf's .grad attribute.

Basic Example

```python
import torch

# Create tensors
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

# Forward pass
z = x**2 + y**3

# Backward pass
z.backward()

# Print gradients
print(f"dz/dx: {x.grad}")  # Should be 2*x = 4
print(f"dz/dy: {y.grad}")  # Should be 3*y^2 = 27
```

Output:

```
dz/dx: tensor(4.)
dz/dy: tensor(27.)
```

In this example:

  • dz/dx = d(x²)/dx = 2x = 2 * 2 = 4
  • dz/dy = d(y³)/dy = 3y² = 3 * 3² = 27
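To double-check these hand-derived values, you can use `torch.autograd.grad`, which is part of PyTorch's public autograd API. It returns the gradients directly instead of accumulating them into `.grad`. A minimal sketch:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
z = x**2 + y**3

# torch.autograd.grad returns the gradients instead of storing them in .grad
dz_dx, dz_dy = torch.autograd.grad(z, (x, y))

print(dz_dx, dz_dy)  # tensor(4.) tensor(27.)
print(x.grad)        # None: nothing was accumulated into x.grad
```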

Gradient Accumulation

By default, PyTorch accumulates gradients when you call .backward() multiple times:

```python
import torch

# Create a tensor
x = torch.tensor(2.0, requires_grad=True)

# First operation
y1 = x**2
y1.backward()
print(f"First gradient: {x.grad}")

# Second operation (gradients accumulate)
y2 = x**3
y2.backward()
print(f"Accumulated gradient: {x.grad}")

# Reset gradients
x.grad.zero_()
print(f"After reset: {x.grad}")
```

Output:

```
First gradient: tensor(4.)
Accumulated gradient: tensor(16.)
After reset: tensor(0.)
```

The accumulated value is 4 (from x²) + 12 (from x³) = 16.

This is why we call optimizer.zero_grad() at the beginning of each training iteration.
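Accumulation is not only a pitfall. It is also how a large batch is often split into smaller ones. The sketch below uses random data, a single `nn.Linear` layer, and an assumed split into 4 equal micro-batches. It checks that the sum of the scaled micro-batch gradients reproduces the full-batch gradient. Dividing each loss by the number of chunks matters because `nn.MSELoss` averages over the elements by default.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
X = torch.randn(8, 4)
y = torch.randn(8, 1)

# Gradient of the full batch in a single backward() call
model.zero_grad()
criterion(model(X), y).backward()
full_grad = model.weight.grad.clone()

# Same batch split into 4 micro-batches; each backward() adds into .grad
num_chunks = 4
model.zero_grad()
for xb, yb in zip(X.chunk(num_chunks), y.chunk(num_chunks)):
    loss = criterion(model(xb), yb) / num_chunks  # scale so the sum equals the full-batch mean
    loss.backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```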

Backward Pass in Neural Networks

In neural networks, we typically use the backward pass to compute gradients of the loss with respect to model parameters.

Simple Neural Network Example

```python
import torch
import torch.nn as nn

# Create a simple model
model = nn.Sequential(
    nn.Linear(2, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)

# Generate some data
X = torch.randn(3, 2)  # 3 samples, 2 features
y = torch.randn(3, 1)  # 3 targets

# Forward pass
outputs = model(X)
criterion = nn.MSELoss()
loss = criterion(outputs, y)
print(f"Loss: {loss.item():.4f}")

# Backward pass
loss.backward()

# Print gradients for first layer weights
print(f"Gradient shape for first layer: {model[0].weight.grad.shape}")
print(f"Gradient sample for first layer:\n{model[0].weight.grad}")
```

Output (your numbers will differ, since the data and initial weights are random):

```
Loss: 1.2345
Gradient shape for first layer: torch.Size([5, 2])
Gradient sample for first layer:
tensor([[ 0.0231, -0.1240],
        [-0.0567,  0.0892],
        [ 0.1124, -0.0345],
        [-0.0781,  0.0256],
        [ 0.0456, -0.0673]])
```
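Every trainable parameter receives a gradient, not just the first layer's weights. This sketch rebuilds the same two-layer model with random data and lists each parameter's gradient shape via `named_parameters()`. `nn.ReLU` has no parameters, so index 1 does not appear.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 5), nn.ReLU(), nn.Linear(5, 1))
X = torch.randn(3, 2)
y = torch.randn(3, 1)

loss = nn.MSELoss()(model(X), y)
loss.backward()

# Each parameter's .grad has the same shape as the parameter itself
grad_shapes = {name: tuple(p.grad.shape) for name, p in model.named_parameters()}
for name, shape in grad_shapes.items():
    print(f"{name}: {shape}")
```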

Advanced Backward Pass Features

Retaining Computation Graph

By default, the computation graph is freed after calling .backward() to save memory. If you need to call backward multiple times on the same graph, use retain_graph=True:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x**3

# First backward pass
y.backward(retain_graph=True)
print(f"First gradient: {x.grad}")

# Reset gradient
x.grad.zero_()

# Second backward pass on same graph
y.backward()
print(f"Second gradient: {x.grad}")
```

Output:

```
First gradient: tensor(12.)
Second gradient: tensor(12.)
```
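Without `retain_graph=True`, a second call fails. The first call frees the buffers saved during the forward pass, such as the `x` needed to differentiate `x**3`. A small sketch that catches the error:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x**3

y.backward()  # no retain_graph: the graph's saved tensors are freed here

try:
    y.backward()  # second pass over the already-freed graph
    second_pass_failed = False
except RuntimeError as err:
    second_pass_failed = True
    print(f"RuntimeError: {err}")

print(second_pass_failed)
```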

Computing Gradients for Non-Scalar Outputs

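When the output is not a scalar, .backward() needs a gradient argument with the same shape as the output. PyTorch then computes the gradient of sum(gradient * y) with respect to the inputs (a vector-Jacobian product), so passing all ones gives the gradient of y.sum():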

```python
import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x**2

# For non-scalar outputs, provide a gradient
external_grad = torch.tensor([1.0, 1.0])
y.backward(gradient=external_grad)

print(f"Gradients: {x.grad}")  # Should be [4.0, 6.0]
```

Output:

```
Gradients: tensor([4., 6.])
```
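With `gradient=v`, PyTorch computes the vector-Jacobian product vᵀJ instead of the full Jacobian. For an elementwise operation such as `x**2`, the Jacobian is diagonal, so each entry is simply v_i · 2x_i. The sketch below uses a non-uniform `v` (an arbitrary choice for illustration). It also shows that passing all ones is equivalent to calling `.sum()` first.

```python
import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)

# Weighted gradient: for elementwise x**2, x.grad = v * 2x
v = torch.tensor([1.0, 0.5])
y = x**2
y.backward(gradient=v)
weighted_grad = x.grad.clone()
print(weighted_grad)  # tensor([4., 3.])

# Passing all ones is the same as reducing with sum() before backward()
x.grad = None
(x**2).sum().backward()
summed_grad = x.grad.clone()
print(summed_grad)  # tensor([4., 6.])
```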

Using torch.no_grad()

To temporarily disable gradient tracking:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

with torch.no_grad():
    # Operations inside this block won't track gradients
    y = x**2
    print(f"requires_grad: {y.requires_grad}")

# Outside the block, gradient tracking works again
z = x**2
print(f"requires_grad: {z.requires_grad}")
```

Output:

```
requires_grad: False
requires_grad: True
```
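A common place where `torch.no_grad()` is required is a hand-written parameter update. The sketch below minimizes the toy loss (w - 3)² with plain gradient descent. The loss, learning rate, and step count are illustrative choices. Outside `torch.no_grad()`, the in-place update of a leaf tensor that requires gradients would raise an error.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
lr = 0.1

for _ in range(3):
    loss = (w - 3.0) ** 2   # minimized at w = 3
    loss.backward()         # w.grad = 2 * (w - 3)
    with torch.no_grad():   # the update itself must not be recorded in the graph
        w -= lr * w.grad    # in-place update of a leaf is allowed here
    w.grad.zero_()          # reset before the next backward()

print(w)  # moves from 1.0 toward 3.0: 1.4, then 1.72, then 1.976
```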

Backward Pass in a Training Loop

Let's put it all together in a complete training loop example:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Create a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

# Initialize model, loss, and optimizer
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Generate random data
X = torch.randn(20, 10)
y_true = torch.randn(20, 1)

# Training loop
for epoch in range(5):
    # Step 1: Reset the gradients accumulated in the previous iteration
    optimizer.zero_grad()

    # Step 2: Forward pass
    y_pred = model(X)
    loss = criterion(y_pred, y_true)

    # Step 3: Backward pass (fills each parameter's .grad)
    loss.backward()

    # Step 4: Update weights using the gradients
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
```

Output (illustrative; your values will differ because the data and initial weights are random):

```
Epoch 1, Loss: 1.2345
Epoch 2, Loss: 1.1123
Epoch 3, Loss: 0.9876
Epoch 4, Loss: 0.8765
Epoch 5, Loss: 0.7654
```
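After training, evaluation usually skips the backward pass entirely. The sketch below uses an `nn.Sequential` stand-in with the same layer sizes as `SimpleModel`, plus random placeholder validation data. `model.eval()` changes the behavior of layers such as dropout and batch normalization. This small model has neither, so the call appears only as the usual habit. `torch.no_grad()` stops autograd from recording a graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
criterion = nn.MSELoss()
X_val = torch.randn(8, 10)  # placeholder validation data (random, for illustration)
y_val = torch.randn(8, 1)

model.eval()           # evaluation behavior for layers like dropout/batch norm
with torch.no_grad():  # no graph is recorded, so there is nothing to backpropagate
    val_loss = criterion(model(X_val), y_val)
model.train()          # back to training mode before the next training step

print(f"Validation loss: {val_loss.item():.4f}")
print(val_loss.requires_grad)  # False
```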

Common Problems and Debugging

Gradient Explosion/Vanishing

```python
import torch
import torch.nn as nn

# Create a deep network: 20 hidden linear layers plus an output layer
deep_model = nn.Sequential(
    *[nn.Linear(10, 10) for _ in range(20)],
    nn.Linear(10, 1)
)

# Initialize with larger values than usual
for layer in deep_model:
    if isinstance(layer, nn.Linear):
        nn.init.normal_(layer.weight, mean=0, std=1.0)

# Forward and backward pass (the output has one element, so no gradient argument is needed)
x = torch.randn(1, 10)
output = deep_model(x)
output.backward()

# Check gradient norms
for i, layer in enumerate(deep_model):
    if isinstance(layer, nn.Linear):
        grad_norm = layer.weight.grad.norm().item()
        print(f"Layer {i} gradient norm: {grad_norm:.6f}")
```

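With std=1.0 and 10 inputs per layer, each layer scales the signal by roughly √10 on average. After 20 layers, the output and the gradient norms are therefore typically enormous (exploding gradients). With a much smaller std, the same effect runs the other way and the gradients shrink toward zero (vanishing gradients).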
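To see both failure modes side by side, the sketch below rebuilds the same 21-layer linear stack for a few weight scales and reports the first layer's gradient norm. `first_layer_grad_norm` is a hypothetical helper written for this illustration. The biases are zeroed and the seed is fixed, so the runs differ only in `std`. The middle value, std = 1/√10, is an assumption based on each layer multiplying the squared signal size by roughly 10 × std². That makes it the scale that keeps the signal size roughly stable; it is not a tuned setting.

```python
import torch
import torch.nn as nn

def first_layer_grad_norm(std, depth=20, width=10, seed=0):
    """Hypothetical helper: gradient norm of the first layer for a given init std."""
    torch.manual_seed(seed)  # identical random draws for every std
    model = nn.Sequential(
        *[nn.Linear(width, width) for _ in range(depth)],
        nn.Linear(width, 1),
    )
    for layer in model:
        nn.init.normal_(layer.weight, mean=0.0, std=std)
        nn.init.zeros_(layer.bias)
    x = torch.randn(1, width)
    model(x).sum().backward()
    return model[0].weight.grad.norm().item()

# Each layer scales the squared signal size by roughly width * std**2
for std in (0.1, 10 ** -0.5, 1.0):
    print(f"std={std:.3f}: first-layer gradient norm = {first_layer_grad_norm(std):.3e}")
```

Expect a vanishingly small norm for the smallest std and a huge one for std=1.0. The exact numbers depend on the seed.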

Using Gradient Clipping

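Gradient clipping caps the size of the gradients before the optimizer step. clip_grad_norm_ computes the combined L2 norm of all parameter gradients. If that norm exceeds max_norm, it scales every gradient down by the same factor:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)
outputs = model(inputs)
loss = nn.MSELoss()(outputs, targets)
loss.backward()

# Clip gradients; the return value is the total norm *before* clipping
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"Before clipping: {total_norm:.4f}")

# Total norm after clipping (weight and bias gradients together)
clipped_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print(f"After clipping: {clipped_norm:.4f}")

optimizer.step()
```

Output (the first value varies from run to run; if it is already below max_norm, the gradients are left unchanged):

```
Before clipping: 2.3456
After clipping: 1.0000
```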
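To make the rescaling rule concrete, the sketch below reimplements it by hand and compares the result with `torch.nn.utils.clip_grad_norm_` on two identical copies of a model. `clip_by_total_norm` is a hypothetical helper written for this illustration. It assumes the default L2 norm and mirrors the documented behavior: a single scale factor max_norm / total_norm, applied only when that factor is below 1, with a small `eps` for numerical safety.

```python
import copy
import torch
import torch.nn as nn

def clip_by_total_norm(parameters, max_norm, eps=1e-6):
    """Hypothetical helper mirroring clip_grad_norm_ for the L2 norm."""
    grads = [p.grad for p in parameters if p.grad is not None]
    # Norm of all gradients treated as one long vector
    total_norm = torch.norm(torch.stack([g.norm() for g in grads]))
    # Scale factor capped at 1.0, so gradients already below max_norm are unchanged
    clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm

torch.manual_seed(0)
model_a = nn.Linear(10, 1)
model_b = copy.deepcopy(model_a)
inputs = torch.randn(5, 10) * 10  # large inputs make it very likely the norm exceeds max_norm
targets = torch.randn(5, 1)
for m in (model_a, model_b):
    nn.MSELoss()(m(inputs), targets).backward()

clip_by_total_norm(model_a.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_norm_(model_b.parameters(), max_norm=1.0)

same = all(torch.allclose(pa.grad, pb.grad, atol=1e-6)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```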

Summary

The backward pass in PyTorch:

  1. Computes gradients of a tensor with respect to parameters that require gradients
  2. Uses PyTorch's autograd system to automatically calculate derivatives
  3. Is essential for training neural networks through gradient-based optimization
  4. Works by traversing the computational graph backward from the output

Key points to remember:

  • Call optimizer.zero_grad() before each backward pass to reset gradients
  • Use loss.backward() to compute gradients
  • Use optimizer.step() to update weights using these gradients
  • Use with torch.no_grad() when you don't need to track gradients
  • For non-scalar outputs, provide a gradient tensor to .backward()

Exercises

  1. Create a simple neural network with 2 hidden layers and compute gradients for a batch of data.
  2. Experiment with different learning rates and observe how the gradients change.
  3. Implement a custom autograd function using PyTorch's torch.autograd.Function class.
  4. Debug a case of vanishing gradients by printing gradient norms at each layer.
  5. Implement gradient accumulation for a large batch by breaking it into smaller batches.

Additional Resources

With this knowledge of the backward pass, you're now ready to implement and debug sophisticated neural network training loops in PyTorch!


