PyTorch Custom Autograd
Introduction
PyTorch's autograd system is a powerful engine for automatic differentiation that handles the calculation of gradients for you. While PyTorch provides a wide range of built-in operations with pre-defined gradients, sometimes you might need to:
- Implement a custom operation not available in PyTorch
- Optimize a specific computation for better performance
- Define a more numerically stable gradient calculation
- Create operations with custom behaviors during the backward pass
In this tutorial, we'll learn how to extend PyTorch's autograd system by creating custom autograd functions. This allows you to seamlessly integrate your own operations with PyTorch's automatic differentiation capabilities.
Understanding Custom Autograd Functions
PyTorch allows you to define custom operations by extending the torch.autograd.Function class. There are two key methods to implement:
- forward(): Defines the computation performed in the forward pass
- backward(): Defines how gradients are computed during the backward pass
Let's start with a simple example to understand the structure.
Basic Structure of a Custom Autograd Function
Here's a template for creating a custom autograd function:
import torch

class MyCustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2, ...):
        # Perform the operation
        output = ...
        # Save variables needed for backward
        ctx.save_for_backward(input1, input2, ...)
        ctx.other_data = ...  # Save non-tensor data if needed
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors and data
        input1, input2, ... = ctx.saved_tensors
        other_data = ctx.other_data
        # Calculate gradients
        grad_input1 = ...
        grad_input2 = ...
        # Return gradients for each input in the same order as forward
        return grad_input1, grad_input2, ...
Let's break down what's happening:
- The ctx parameter is a context object that is used to store information for the backward pass
- save_for_backward() stores tensors that will be needed during the backward pass
- backward() must return the same number of gradients as there are inputs to forward()
- Each gradient should have the same shape as its corresponding input
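To make the template concrete, here is a minimal sketch of a two-input function. The class name Multiply and the sample values are ours, chosen only for illustration; the point is that backward() returns one gradient per forward() input, in the same order and with matching shapes.

```python
import torch

class Multiply(torch.autograd.Function):
    """Hypothetical two-input example: elementwise product a * b."""

    @staticmethod
    def forward(ctx, a, b):
        # Both inputs are needed later to compute the gradients
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # d(a*b)/da = b and d(a*b)/db = a, each scaled by the incoming gradient
        grad_a = grad_output * b
        grad_b = grad_output * a
        # One gradient per forward input, in the same order
        return grad_a, grad_b

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)
out = Multiply.apply(a, b)
out.sum().backward()
print(a.grad)  # tensor([4., 5., 6.])
print(b.grad)  # tensor([1., 2., 3.])
```

Because the loss here is a plain sum, grad_output is all ones, so each input's gradient is simply the other input.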
Simple Example: Custom ReLU Function
Let's implement a custom ReLU (Rectified Linear Unit) function to get a better understanding:
import torch

class CustomReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        ctx.save_for_backward(input_tensor)
        return input_tensor.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_tensor < 0] = 0
        return grad_input

# To use the custom function:
custom_relu = CustomReLU.apply
Let's see how to use this function:
# Create an input tensor that requires gradients
x = torch.randn(5, requires_grad=True)
print(f"Input: {x}")
# Apply our custom ReLU
y = custom_relu(x)
print(f"Output: {y}")
# Compute gradients
y.sum().backward()
print(f"Gradient of x: {x.grad}")
Example output:
Input: tensor([-1.2032, 0.8964, 0.0677, -0.5543, 1.1631], requires_grad=True)
Output: tensor([0.0000, 0.8964, 0.0677, 0.0000, 1.1631], grad_fn=<CustomReLUBackward>)
Gradient of x: tensor([0., 1., 1., 0., 1.])
Notice how the gradient is 1 for positive inputs and 0 for negative inputs, which is exactly the derivative of the ReLU function!
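Eyeballing the gradient works for ReLU, but for less obvious functions a numerical check is safer. The sketch below uses torch.autograd.gradcheck, which compares the hand-written backward() against finite differences. The input values, eps and atol are our own choices; the inputs are kept away from 0, where ReLU is not differentiable.

```python
import torch
from torch.autograd import gradcheck

class CustomReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        ctx.save_for_backward(input_tensor)
        return input_tensor.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_tensor < 0] = 0
        return grad_input

# gradcheck relies on finite differences, so it expects float64 inputs.
# The values are chosen away from 0 to avoid the kink in ReLU.
x = torch.tensor([-1.5, -0.3, 0.4, 2.0], dtype=torch.double, requires_grad=True)
ok = gradcheck(CustomReLU.apply, (x,), eps=1e-6, atol=1e-4)
print(f"gradcheck passed: {ok}")
```

If the analytical and numerical gradients disagree, gradcheck raises an error describing the mismatch instead of returning True.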
More Complex Example: Custom Exponential Linear Unit (ELU)
Let's implement a more complex function, the Exponential Linear Unit (ELU):
class CustomELU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, alpha=1.0):
        ctx.alpha = alpha
        ctx.save_for_backward(input_tensor)
        output = input_tensor.clone()
        negative_indices = input_tensor < 0
        output[negative_indices] = alpha * (torch.exp(input_tensor[negative_indices]) - 1)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        alpha = ctx.alpha
        grad_input = grad_output.clone()
        negative_indices = input_tensor < 0
        grad_input[negative_indices] = grad_output[negative_indices] * (
            alpha * torch.exp(input_tensor[negative_indices])
        )
        # We return None for the gradient of alpha since it's a parameter, not an input
        return grad_input, None

# To use the custom function:
custom_elu = CustomELU.apply
Let's use our custom ELU function:
x = torch.randn(5, requires_grad=True)
alpha = 0.5
print(f"Input: {x}")
# Apply our custom ELU
y = custom_elu(x, alpha)
print(f"Custom ELU output: {y}")
# Compare with PyTorch's built-in ELU
z = torch.nn.functional.elu(x, alpha)
print(f"PyTorch's ELU output: {z}")
# Check if they're the same
print(f"Outputs match: {torch.allclose(y, z)}")
# Compute gradients
y.sum().backward()
print(f"Gradient of x: {x.grad}")
Example output:
Input: tensor([-0.7526, 0.9517, -1.2040, 0.8450, -0.4021], requires_grad=True)
Custom ELU output: tensor([-0.2642, 0.9517, -0.3496, 0.8450, -0.1655], grad_fn=<CustomELUBackward>)
PyTorch's ELU output: tensor([-0.2642, 0.9517, -0.3496, 0.8450, -0.1655], grad_fn=<EluBackward0>)
Outputs match: True
Gradient of x: tensor([0.2359, 1.0000, 0.1503, 1.0000, 0.3345])
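The comparison above covers only the forward outputs. As a rough sketch, the same idea extends to the backward pass: run both implementations on the same input and compare x.grad. The fixed input values below are our own, picked so the run is reproducible.

```python
import torch
import torch.nn.functional as F

class CustomELU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, alpha=1.0):
        ctx.alpha = alpha
        ctx.save_for_backward(input_tensor)
        output = input_tensor.clone()
        negative_indices = input_tensor < 0
        output[negative_indices] = alpha * (torch.exp(input_tensor[negative_indices]) - 1)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        grad_input = grad_output.clone()
        negative_indices = input_tensor < 0
        grad_input[negative_indices] = grad_output[negative_indices] * (
            ctx.alpha * torch.exp(input_tensor[negative_indices])
        )
        return grad_input, None  # no gradient for alpha

x = torch.tensor([-0.75, 0.95, -1.2, 0.85, -0.4], requires_grad=True)
alpha = 0.5

# Gradient from the custom function
CustomELU.apply(x, alpha).sum().backward()
grad_custom = x.grad.clone()

# Gradient from PyTorch's built-in ELU on the same input
x.grad = None
F.elu(x, alpha).sum().backward()
grad_builtin = x.grad.clone()

print(f"Gradients match: {torch.allclose(grad_custom, grad_builtin)}")
```

Resetting x.grad between the two backward calls matters, because PyTorch accumulates gradients into .grad rather than overwriting them.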
Custom Autograd for Numerical Stability
One key advantage of custom autograd functions is implementing numerically stable operations. Let's see an example of a log-sum-exp operation:
class LogSumExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, dim=0):
        # For numerical stability, we subtract the max value before exponentiating
        max_val, _ = torch.max(input_tensor, dim=dim, keepdim=True)
        output = max_val + (input_tensor - max_val).exp().sum(dim=dim, keepdim=True).log()
        # Normalize dim to a tuple before storing it on ctx
        # (the output always keeps the reduced dimension)
        if not isinstance(dim, tuple):
            dim = (dim,)
        ctx.dim = dim
        ctx.save_for_backward(input_tensor, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, output = ctx.saved_tensors
        dim = ctx.dim
        # Compute gradient: output keeps the reduced dimension, so it broadcasts
        grad_input = grad_output * torch.exp(input_tensor - output)
        # Return None for the gradient of dim since it's not a tensor
        return grad_input, None

# To use the custom function:
log_sum_exp = LogSumExp.apply
Here's how to use it:
# Create a tensor with some large values
x = torch.tensor([[1000., 1000.], [1000., 1000.]], requires_grad=True)
# Using our custom log-sum-exp (dim is passed positionally because
# Function.apply does not accept keyword arguments for this style of Function)
result_custom = log_sum_exp(x, 1)
print(f"Custom LogSumExp: {result_custom}")
# Using a naive implementation: exp(1000) overflows to inf in float32,
# so the result is inf rather than an exception
result_naive = torch.log(torch.exp(x).sum(dim=1, keepdim=True))
print(f"Naive implementation: {result_naive}")
# Test backpropagation
result_custom.sum().backward()
print(f"Gradient of x: {x.grad}")
Example output:
Custom LogSumExp: tensor([[1000.6931],
        [1000.6931]], grad_fn=<LogSumExpBackward>)
Naive implementation: tensor([[inf],
        [inf]], grad_fn=<LogBackward0>)
Gradient of x: tensor([[0.5000, 0.5000],
        [0.5000, 0.5000]])
Our custom implementation handles large values correctly, while the naive approach silently overflows to inf.
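The max-subtraction in forward() and the expression used in backward() both follow from a short derivation. It is sketched here with our own symbols: m is the maximum along the reduced dimension and y is the output.

```latex
\begin{aligned}
y = \log \sum_i e^{x_i}
  &= \log\Big( e^{m} \sum_i e^{x_i - m} \Big)
   = m + \log \sum_i e^{x_i - m},
  \qquad m = \max_i x_i, \\
\frac{\partial y}{\partial x_j}
  &= \frac{e^{x_j}}{\sum_i e^{x_i}}
   = e^{x_j - y}.
\end{aligned}
```

Every exponent x_i - m is at most 0, so exp() cannot overflow. In the example, each row holds two equal entries, which gives y = 1000 + log 2 ≈ 1000.6931 and a gradient of 1/2 per entry.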
Real-World Application: Custom Loss Function
Let's implement a custom focal loss function, which is commonly used for imbalanced classification problems:
class FocalLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, predictions, targets, gamma=2.0, alpha=0.25):
        """
        Focal loss for binary classification
        predictions: sigmoid outputs
        targets: binary labels (0 or 1)
        gamma: focusing parameter
        alpha: balancing parameter
        """
        ctx.gamma = gamma
        ctx.alpha = alpha
        ctx.save_for_backward(predictions, targets)
        # Calculate binary cross entropy
        bce = -targets * torch.log(predictions) - (1 - targets) * torch.log(1 - predictions)
        # Apply focal loss formula
        pt = torch.where(targets == 1, predictions, 1 - predictions)
        focal_weight = (1 - pt) ** gamma
        # Apply alpha weighting
        alpha_weight = torch.where(targets == 1, alpha, 1 - alpha)
        loss = alpha_weight * focal_weight * bce
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        predictions, targets = ctx.saved_tensors
        gamma = ctx.gamma
        alpha = ctx.alpha
        # Compute focal loss gradient
        pt = torch.where(targets == 1, predictions, 1 - predictions)
        focal_weight = (1 - pt) ** gamma
        # Alpha weight
        alpha_weight = torch.where(targets == 1, alpha, 1 - alpha)
        # Term 1: gradient from BCE
        term1 = torch.where(targets == 1, -1/predictions, 1/(1 - predictions))
        # Term 2: gradient from focal weight
        term2 = gamma * (1 - pt) ** (gamma - 1) * (-1) * torch.where(targets == 1, 1, -1)
        term2 = term2 * (-targets * torch.log(predictions) - (1 - targets) * torch.log(1 - predictions))
        # Combined gradient
        grad_input = alpha_weight * (focal_weight * term1 + term2) * grad_output
        # Return gradients for each input (None for parameters)
        return grad_input, None, None, None

# To use the custom function:
focal_loss = FocalLoss.apply
Here's how to use the focal loss in a binary classification scenario:
# Create dummy predictions and targets
# (predictions must be a leaf tensor, otherwise predictions.grad stays None)
predictions = torch.sigmoid(torch.randn(5)).requires_grad_()
targets = torch.tensor([0., 1., 0., 1., 0.])
print(f"Predictions: {predictions}")
print(f"Targets: {targets}")
# Apply focal loss
loss = focal_loss(predictions, targets)
print(f"Focal loss: {loss}")
# Compute gradients
loss.sum().backward()
print(f"Gradients: {predictions.grad}")
# Compare with a simple BCE loss
bce_loss = torch.nn.functional.binary_cross_entropy(predictions, targets)
print(f"BCE loss: {bce_loss}")
Example output:
Predictions: tensor([0.4228, 0.6925, 0.5097, 0.5529, 0.3885], requires_grad=True)
Targets: tensor([0., 1., 0., 1., 0.])
Focal loss: tensor([0.2948, 0.1244, 0.3537, 0.2422, 0.2384], grad_fn=<FocalLossBackward>)
Gradients: tensor([ 0.3956, -0.2131, 0.5537, -0.4307, 0.3133])
BCE loss: tensor(0.5012)
Notice how the focal loss gives different weights to easy vs. hard examples.
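A hand-derived gradient like the one in backward() is easy to get subtly wrong, so it is worth cross-checking against autograd. The sketch below assumes binary 0/1 targets. It writes the same loss with ordinary differentiable ops and compares autograd's gradient with a manual formula organised like term1 and term2. The helper names and sample values are ours.

```python
import torch

def focal_loss_reference(p, t, gamma=2.0, alpha=0.25):
    """Focal loss built from ordinary differentiable ops (binary targets)."""
    pt = t * p + (1 - t) * (1 - p)               # probability of the true class
    alpha_w = alpha * t + (1 - alpha) * (1 - t)  # class-balancing weight
    return -alpha_w * (1 - pt) ** gamma * torch.log(pt)

def focal_loss_grad_manual(p, t, gamma=2.0, alpha=0.25):
    """Hand-derived d(loss)/d(p), organised like term1/term2 in backward()."""
    pt = t * p + (1 - t) * (1 - p)
    alpha_w = alpha * t + (1 - alpha) * (1 - t)
    sign = 2 * t - 1                   # d(pt)/dp: +1 for positives, -1 for negatives
    bce = -torch.log(pt)
    term1 = -sign / pt                 # derivative of the BCE factor
    # derivative of the focal weight, multiplied by the BCE factor
    term2 = -gamma * (1 - pt) ** (gamma - 1) * sign * bce
    return alpha_w * ((1 - pt) ** gamma * term1 + term2)

# float64 keeps rounding error small for the comparison
p = torch.tensor([0.42, 0.69, 0.51, 0.55, 0.39], dtype=torch.double, requires_grad=True)
t = torch.tensor([0., 1., 0., 1., 0.], dtype=torch.double)

focal_loss_reference(p, t).sum().backward()
manual = focal_loss_grad_manual(p.detach(), t)
print(f"Manual gradient matches autograd: {torch.allclose(p.grad, manual)}")
```

The reference version needs no custom Function at all, which is a useful reminder: when a loss can be written with built-in ops, autograd already derives the backward pass for you.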
Performance Considerations
Custom autograd functions can also be used for performance optimization. Here's an example that demonstrates a batch matrix multiplication with custom gradient computation:
class FastBatchMatMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        return torch.bmm(A, B)

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        # Gradient with respect to A
        grad_A = torch.bmm(grad_output, B.transpose(1, 2))
        # Gradient with respect to B
        grad_B = torch.bmm(A.transpose(1, 2), grad_output)
        return grad_A, grad_B

# To use the custom function:
fast_batch_matmul = FastBatchMatMul.apply
Let's compare the performance:
import time
# Create batch matrices
batch_size = 1000
matrix_size = 100
A = torch.randn(batch_size, matrix_size, 50, requires_grad=True)
B = torch.randn(batch_size, 50, matrix_size, requires_grad=True)
# Time the custom implementation (perf_counter is a monotonic, high-resolution clock)
start = time.perf_counter()
C_custom = fast_batch_matmul(A, B)
C_custom.sum().backward()
custom_time = time.perf_counter() - start
# Reset gradients
A.grad = None
B.grad = None
# Time the standard implementation
start = time.perf_counter()
C_standard = torch.bmm(A, B)
C_standard.sum().backward()
standard_time = time.perf_counter() - start
print(f"Custom implementation time: {custom_time:.4f} seconds")
print(f"Standard implementation time: {standard_time:.4f} seconds")
print(f"Same result: {torch.allclose(C_custom, C_standard)}")
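In this particular example the custom backward issues the same two torch.bmm calls that PyTorch's built-in gradient for bmm performs, so the two timings should be close. A custom function tends to pay off when its backward pass can reuse intermediate results from the forward pass, fuse several steps, or skip work that the generic graph would otherwise do.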
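One way a custom backward can skip work is to consult ctx.needs_input_grad, a tuple of booleans with one entry per forward() input. The sketch below is a variant of the batch matmul function (the class name BatchMatMulSkip and the tensor sizes are ours) that computes a gradient only for inputs that actually require one.

```python
import torch

class BatchMatMulSkip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        return torch.bmm(A, B)

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        grad_A = grad_B = None
        # Only pay for a bmm when the corresponding input needs a gradient
        if ctx.needs_input_grad[0]:
            grad_A = torch.bmm(grad_output, B.transpose(1, 2))
        if ctx.needs_input_grad[1]:
            grad_B = torch.bmm(A.transpose(1, 2), grad_output)
        return grad_A, grad_B

A = torch.randn(4, 3, 5, requires_grad=True)
B = torch.randn(4, 5, 2)  # e.g. frozen weights: no gradient requested

out = BatchMatMulSkip.apply(A, B)
out.sum().backward()
print(A.grad.shape)  # torch.Size([4, 3, 5])
print(B.grad)        # None
```

Returning None for an input that does not need a gradient is allowed, and here it saves one of the two batched matrix products in the backward pass.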
Summary
In this tutorial, you've learned how to create custom autograd functions in PyTorch by extending the torch.autograd.Function class. We covered:
- The basic structure of a custom autograd function
- How to implement forward and backward passes
- Simple examples like custom ReLU and ELU
- Numerically stable computations with custom autograd
- A real-world application with focal loss
- Performance optimizations using custom gradient calculations
Custom autograd functions give you the flexibility to define your own operations while still benefiting from PyTorch's automatic differentiation capabilities. They are particularly useful for:
- Implementing operations not available in PyTorch
- Ensuring numerical stability in complex computations
- Optimizing performance for specific use cases
- Implementing custom loss functions
Exercises
To reinforce your understanding, try these exercises:
- Implement a custom softmax function with numerical stability
- Create a custom layer normalization function
- Implement a custom attention mechanism
- Create a custom optimization operation that fuses multiple operations for better performance