
PyTorch Custom Autograd

Introduction

PyTorch's autograd system is a powerful engine for automatic differentiation that handles the calculation of gradients for you. While PyTorch provides a wide range of built-in operations with pre-defined gradients, sometimes you might need to:

  • Implement a custom operation not available in PyTorch
  • Optimize a specific computation for better performance
  • Define a more numerically stable gradient calculation
  • Create operations with custom behaviors during the backward pass

In this tutorial, we'll learn how to extend PyTorch's autograd system by creating custom autograd functions. This allows you to seamlessly integrate your own operations with PyTorch's automatic differentiation capabilities.

Understanding Custom Autograd Functions

PyTorch allows you to define custom operations by subclassing torch.autograd.Function. There are two key methods to implement, both declared as static methods:

  1. forward(): Defines the computation performed in the forward pass
  2. backward(): Defines how gradients are computed during the backward pass

Let's start with a simple example to understand the structure.

Basic Structure of a Custom Autograd Function

Here's a template for creating a custom autograd function:

python
import torch

class MyCustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2):
        # Perform the operation (placeholder: replace ... with the real computation)
        output = ...

        # Save tensors needed for backward
        ctx.save_for_backward(input1, input2)
        ctx.other_data = ...  # Save non-tensor data as plain attributes if needed

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors and data
        input1, input2 = ctx.saved_tensors
        other_data = ctx.other_data

        # Calculate gradients (placeholders)
        grad_input1 = ...
        grad_input2 = ...

        # Return one gradient per forward() input, in the same order
        return grad_input1, grad_input2

Let's break down what's happening:

  • The ctx parameter is a context object that is used to store information for the backward pass
  • save_for_backward() stores tensors that will be needed during the backward pass
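  • backward() receives one gradient per output of forward() and must return one value per input to forward() (not counting ctx); return None for inputs that do not need a gradient, such as non-tensor arguments
  • Each returned gradient should have the same shape as its corresponding input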
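
One way to check that a backward() follows these rules is torch.autograd.gradcheck, which compares the analytical gradients against finite differences. The sketch below uses a hypothetical two-input MulFunction (not one of this tutorial's examples) and double-precision inputs, which gradcheck expects for its default tolerances.

```python
import torch

class MulFunction(torch.autograd.Function):
    """Hypothetical example: elementwise product of two tensors."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # One gradient per forward() input, each shaped like that input
        return grad_output * b, grad_output * a

a = torch.randn(4, dtype=torch.double, requires_grad=True)
b = torch.randn(4, dtype=torch.double, requires_grad=True)

# gradcheck returns True when analytical and numerical gradients agree
gradcheck_passed = torch.autograd.gradcheck(MulFunction.apply, (a, b), eps=1e-6, atol=1e-4)
print(f"gradcheck passed: {gradcheck_passed}")
```

If backward() returned the wrong number of values or a wrong formula, gradcheck would raise an error describing the mismatch instead of returning True.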

Simple Example: Custom ReLU Function

Let's implement a custom ReLU (Rectified Linear Unit) function to get a better understanding:

python
import torch

class CustomReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        ctx.save_for_backward(input_tensor)
        return input_tensor.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_tensor < 0] = 0
        return grad_input

# To use the custom function:
custom_relu = CustomReLU.apply

Let's see how to use this function:

python
# Create an input tensor that requires gradients
x = torch.randn(5, requires_grad=True)
print(f"Input: {x}")

# Apply our custom ReLU
y = custom_relu(x)
print(f"Output: {y}")

# Compute gradients
y.sum().backward()
print(f"Gradient of x: {x.grad}")

Example output:

Input: tensor([-1.2032,  0.8964,  0.0677, -0.5543,  1.1631], requires_grad=True)
Output: tensor([0.0000, 0.8964, 0.0677, 0.0000, 1.1631], grad_fn=<CustomReLUBackward>)
Gradient of x: tensor([0., 1., 1., 0., 1.])

Notice how the gradient is 1 for positive inputs and 0 for negative inputs, which is exactly the derivative of the ReLU function!
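
A quick sanity check is to compare this gradient with the built-in torch.relu on a fixed input. The sketch below repeats the CustomReLU class so that it runs on its own; the input values are chosen only for illustration. The two gradients agree everywhere except at exactly zero, where ReLU is not differentiable: the mask `input_tensor < 0` lets the gradient through at 0, while the built-in operation returns 0 there. Both are valid subgradient choices.

```python
import torch

class CustomReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        ctx.save_for_backward(input_tensor)
        return input_tensor.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_tensor < 0] = 0
        return grad_input

values = [-1.0, 0.0, 2.0]  # illustrative inputs, including the kink at zero

x_custom = torch.tensor(values, requires_grad=True)
CustomReLU.apply(x_custom).sum().backward()

x_builtin = torch.tensor(values, requires_grad=True)
torch.relu(x_builtin).sum().backward()

print(f"Custom gradient:   {x_custom.grad}")
print(f"Built-in gradient: {x_builtin.grad}")
```

If you want the custom version to match the built-in behavior at zero as well, masking with `input_tensor <= 0` in backward() would do it.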

More Complex Example: Custom Exponential Linear Unit (ELU)

Let's implement a more complex function, the Exponential Linear Unit (ELU):

python
class CustomELU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, alpha=1.0):
        ctx.alpha = alpha
        ctx.save_for_backward(input_tensor)

        output = input_tensor.clone()
        negative_indices = input_tensor < 0
        output[negative_indices] = alpha * (torch.exp(input_tensor[negative_indices]) - 1)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        alpha = ctx.alpha

        grad_input = grad_output.clone()
        negative_indices = input_tensor < 0
        grad_input[negative_indices] = grad_output[negative_indices] * (
            alpha * torch.exp(input_tensor[negative_indices])
        )

        # We return None for the gradient of alpha since it's a parameter, not an input
        return grad_input, None

# To use the custom function:
custom_elu = CustomELU.apply
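
For reference, the piecewise function and derivative that this forward() and backward() implement can be written as follows, using the same `< 0` split as the code:

```latex
\mathrm{ELU}(x) =
\begin{cases}
x, & x \ge 0 \\
\alpha\,(e^{x} - 1), & x < 0
\end{cases}
\qquad
\frac{d\,\mathrm{ELU}}{dx} =
\begin{cases}
1, & x \ge 0 \\
\alpha\, e^{x}, & x < 0
\end{cases}
```

In backward(), this derivative is multiplied elementwise by grad_output, which is the chain rule applied to whatever computation follows the ELU.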

Let's use our custom ELU function:

python
x = torch.randn(5, requires_grad=True)
alpha = 0.5

print(f"Input: {x}")

# Apply our custom ELU
y = custom_elu(x, alpha)
print(f"Custom ELU output: {y}")

# Compare with PyTorch's built-in ELU
z = torch.nn.functional.elu(x, alpha)
print(f"PyTorch's ELU output: {z}")

# Check if they're the same
print(f"Outputs match: {torch.allclose(y, z)}")

# Compute gradients
y.sum().backward()
print(f"Gradient of x: {x.grad}")

Example output:

Input: tensor([-0.7526,  0.9517, -1.2040,  0.8450, -0.4021], requires_grad=True)
Custom ELU output: tensor([-0.2642, 0.9517, -0.3496, 0.8450, -0.1655], grad_fn=<CustomELUBackward>)
PyTorch's ELU output: tensor([-0.2642, 0.9517, -0.3496, 0.8450, -0.1655], grad_fn=<EluBackward0>)
Outputs match: True
Gradient of x: tensor([0.2359, 1.0000, 0.1503, 1.0000, 0.3345])
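
The comparison above checks only the forward outputs. A similar check can be made for the backward pass by comparing gradients against torch.nn.functional.elu. The sketch below repeats a compact copy of CustomELU so that it runs on its own; the input values and the variable names are illustrative.

```python
import torch
import torch.nn.functional as F

class CustomELU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, alpha=1.0):
        ctx.alpha = alpha
        ctx.save_for_backward(input_tensor)
        negative = input_tensor < 0
        output = input_tensor.clone()
        output[negative] = alpha * (torch.exp(input_tensor[negative]) - 1)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, = ctx.saved_tensors
        negative = input_tensor < 0
        grad_input = grad_output.clone()
        grad_input[negative] = grad_output[negative] * ctx.alpha * torch.exp(input_tensor[negative])
        return grad_input, None

alpha = 0.5
values = [-2.0, -0.5, 0.3, 1.5]  # illustrative inputs, none exactly zero

x_custom = torch.tensor(values, requires_grad=True)
CustomELU.apply(x_custom, alpha).sum().backward()

x_builtin = torch.tensor(values, requires_grad=True)
F.elu(x_builtin, alpha=alpha).sum().backward()

gradients_match = torch.allclose(x_custom.grad, x_builtin.grad)
print(f"Gradients match: {gradients_match}")
```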

Custom Autograd for Numerical Stability

One key advantage of custom autograd functions is implementing numerically stable operations. Let's see an example of a log-sum-exp operation:

python
class LogSumExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, dim=0):
        # For numerical stability, we subtract the max value before exponentiating
        max_val, _ = torch.max(input_tensor, dim=dim, keepdim=True)
        output = max_val + (input_tensor - max_val).exp().sum(dim=dim, keepdim=True).log()

        # The reduced dimension is kept (with size 1) so that output
        # broadcasts against input_tensor in backward
        ctx.save_for_backward(input_tensor, output)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, output = ctx.saved_tensors

        # d/dx_i logsumexp(x) = exp(x_i - logsumexp(x)), i.e. softmax(x)_i
        grad_input = grad_output * torch.exp(input_tensor - output)

        # Return None for the gradient of dim since it's not a tensor
        return grad_input, None

# To use the custom function:
log_sum_exp = LogSumExp.apply

Here's how to use it:

python
# Create a tensor with some large values
x = torch.tensor([[1000., 1000.], [1000., 1000.]], requires_grad=True)

# Using our custom log-sum-exp. dim is passed positionally because
# apply() generally rejects keyword arguments for this style of Function
result_custom = log_sum_exp(x, 1)
print(f"Custom LogSumExp: {result_custom}")

# Using a naive implementation: exp(1000) overflows to inf in float32,
# which propagates silently instead of raising an error
result_naive = torch.log(torch.exp(x).sum(dim=1, keepdim=True))
print(f"Naive implementation: {result_naive}")

# Test backpropagation
result_custom.sum().backward()
print(f"Gradient of x: {x.grad}")

Example output:

Custom LogSumExp: tensor([[1000.6931],
        [1000.6931]], grad_fn=<LogSumExpBackward>)
Naive implementation: tensor([[inf],
        [inf]], grad_fn=<LogBackward0>)
Gradient of x: tensor([[0.5000, 0.5000],
        [0.5000, 0.5000]])

Our custom implementation handles large values without overflow, while the naive approach silently returns inf.
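
The gradient of log-sum-exp is the softmax of the input, which is why every entry of the gradient for two equal values is 0.5. The sketch below checks that identity with the built-in torch.logsumexp and torch.softmax; the input values are illustrative, and double precision is used so that the comparison is not affected by float32 rounding at magnitudes around 1000.

```python
import torch

x = torch.tensor(
    [[1.0, 2.0, 3.0], [1000.0, 1000.0, 999.0]],
    dtype=torch.double,
    requires_grad=True,
)

# Built-in, numerically stable log-sum-exp over each row
y = torch.logsumexp(x, dim=1)
y.sum().backward()

# The gradient of log-sum-exp with respect to x is softmax(x)
expected = torch.softmax(x.detach(), dim=1)
gradient_is_softmax = torch.allclose(x.grad, expected)
print(f"Gradient equals softmax: {gradient_is_softmax}")
```

The same check can be run against a custom function's backward() to confirm that it computes exp(x - logsumexp(x)).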

Real-World Application: Custom Loss Function

Let's implement a custom focal loss function, which is commonly used for imbalanced classification problems:

python
class FocalLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, predictions, targets, gamma=2.0, alpha=0.25):
        """
        Focal loss for binary classification
        predictions: sigmoid outputs
        targets: binary labels (0 or 1)
        gamma: focusing parameter
        alpha: balancing parameter
        """
        ctx.gamma = gamma
        ctx.alpha = alpha
        ctx.save_for_backward(predictions, targets)

        # Calculate binary cross entropy
        bce = -targets * torch.log(predictions) - (1 - targets) * torch.log(1 - predictions)

        # Apply focal loss formula
        pt = torch.where(targets == 1, predictions, 1 - predictions)
        focal_weight = (1 - pt) ** gamma

        # Apply alpha weighting
        alpha_weight = torch.where(targets == 1, alpha, 1 - alpha)

        loss = alpha_weight * focal_weight * bce

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        predictions, targets = ctx.saved_tensors
        gamma = ctx.gamma
        alpha = ctx.alpha

        # Compute focal loss gradient
        pt = torch.where(targets == 1, predictions, 1 - predictions)
        focal_weight = (1 - pt) ** gamma

        # Alpha weight
        alpha_weight = torch.where(targets == 1, alpha, 1 - alpha)

        # Term 1: gradient from BCE
        term1 = torch.where(targets == 1, -1/predictions, 1/(1 - predictions))

        # Term 2: gradient from focal weight
        term2 = gamma * (1 - pt) ** (gamma - 1) * (-1) * torch.where(targets == 1, 1, -1)
        term2 = term2 * (-targets * torch.log(predictions) - (1 - targets) * torch.log(1 - predictions))

        # Combined gradient
        grad_input = alpha_weight * (focal_weight * term1 + term2) * grad_output

        # Return gradients for each input (None for parameters)
        return grad_input, None, None, None

# To use the custom function:
focal_loss = FocalLoss.apply
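
For reference, the loss and gradient that this code computes can be written out as follows. The symbols p (a prediction) and y (its binary target) are notation introduced here for the derivation; for binary targets the BCE term in the code equals -log p_t.

```latex
p_t = \begin{cases} p, & y = 1 \\ 1 - p, & y = 0 \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha, & y = 1 \\ 1 - \alpha, & y = 0 \end{cases}

\mathrm{FL}(p, y) = -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log p_t

\frac{\partial\,\mathrm{FL}}{\partial p}
= \frac{\partial p_t}{\partial p}\;
  \alpha_t \left[ \gamma\,(1 - p_t)^{\gamma - 1} \log p_t \;-\; \frac{(1 - p_t)^{\gamma}}{p_t} \right],
\qquad
\frac{\partial p_t}{\partial p} = \begin{cases} +1, & y = 1 \\ -1, & y = 0 \end{cases}
```

In backward(), the second term inside the brackets corresponds to `focal_weight * term1` and the first to `term2`; the result is then scaled by grad_output.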

Here's how to use the focal loss in a binary classification scenario:

python
# Create dummy predictions as a leaf tensor (so that .grad is populated) and targets
predictions = torch.sigmoid(torch.randn(5)).requires_grad_()
targets = torch.tensor([0., 1., 0., 1., 0.])

print(f"Predictions: {predictions}")
print(f"Targets: {targets}")

# Apply focal loss
loss = focal_loss(predictions, targets)
print(f"Focal loss: {loss}")

# Compute gradients
loss.sum().backward()
print(f"Gradients: {predictions.grad}")

# Compare with a simple BCE loss (averaged over all elements)
bce_loss = torch.nn.functional.binary_cross_entropy(predictions.detach(), targets)
print(f"BCE loss: {bce_loss}")

Example output (values vary between runs because the predictions are random):

Predictions: tensor([0.4228, 0.6925, 0.5097, 0.5529, 0.3885], requires_grad=True)
Targets: tensor([0., 1., 0., 1., 0.])
Focal loss: tensor([0.0737, 0.0087, 0.1389, 0.0296, 0.0557], grad_fn=<FocalLossBackward>)
Gradients: tensor([ 0.5808, -0.0906,  0.9423, -0.2229,  0.4717])
BCE loss: tensor(0.5428)

Notice how the focal loss down-weights easy examples: the second prediction (0.6925 for a positive target) is the most confident correct one and receives by far the smallest loss, while the third (0.5097 for a negative target) is the least certain and receives the largest.

Performance Considerations

Custom autograd functions can also be used for performance optimization. Here's an example that demonstrates a batch matrix multiplication with custom gradient computation:

python
class FastBatchMatMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        return torch.bmm(A, B)

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors

        # Gradient with respect to A
        grad_A = torch.bmm(grad_output, B.transpose(1, 2))

        # Gradient with respect to B
        grad_B = torch.bmm(A.transpose(1, 2), grad_output)

        return grad_A, grad_B

# To use the custom function:
fast_batch_matmul = FastBatchMatMul.apply
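
One saving that a custom backward() can make is to skip gradients that nobody requested. The sketch below is a hypothetical variant (the name SelectiveBatchMatMul is made up for illustration) that consults ctx.needs_input_grad, a tuple with one boolean per forward() input, before doing each matrix product.

```python
import torch

class SelectiveBatchMatMul(torch.autograd.Function):
    """Hypothetical variant that skips gradients nobody asked for."""

    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        return torch.bmm(A, B)

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        grad_A = grad_B = None

        # needs_input_grad has one boolean per forward() input
        if ctx.needs_input_grad[0]:
            grad_A = torch.bmm(grad_output, B.transpose(1, 2))
        if ctx.needs_input_grad[1]:
            grad_B = torch.bmm(A.transpose(1, 2), grad_output)

        return grad_A, grad_B

A = torch.randn(4, 3, 5, requires_grad=True)
B = torch.randn(4, 5, 2)  # constant input: no gradient required

SelectiveBatchMatMul.apply(A, B).sum().backward()
print(f"A.grad shape: {tuple(A.grad.shape)}, B.grad: {B.grad}")
```

Here only the product for A is computed in the backward pass, because B does not require a gradient.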

Let's compare the performance:

python
import time

# Create batch matrices
batch_size = 1000
matrix_size = 100
A = torch.randn(batch_size, matrix_size, 50, requires_grad=True)
B = torch.randn(batch_size, 50, matrix_size, requires_grad=True)

# Warm-up run so one-time initialization costs don't skew the first measurement
torch.bmm(A, B).sum().backward()
A.grad = None
B.grad = None

# Time the custom implementation
start = time.perf_counter()
C_custom = fast_batch_matmul(A, B)
C_custom.sum().backward()
custom_time = time.perf_counter() - start

# Reset gradients
A.grad = None
B.grad = None

# Time the standard implementation
start = time.perf_counter()
C_standard = torch.bmm(A, B)
C_standard.sum().backward()
standard_time = time.perf_counter() - start

print(f"Custom implementation time: {custom_time:.4f} seconds")
print(f"Standard implementation time: {standard_time:.4f} seconds")
print(f"Same result: {torch.allclose(C_custom, C_standard)}")

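Expect the two timings to be close here: this backward() performs the same two torch.bmm products that the built-in operation needs for its own gradients. A custom function tends to pay off only when it does less work than the default graph, for example by saving fewer intermediate tensors or by combining several steps into a single backward computation.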
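
Single-run wall-clock timings are noisy, so for comparisons like this it can help to use torch.utils.benchmark, which repeats the statement and reports statistics. The sketch below times the built-in torch.bmm forward and backward pass; the tensor sizes, the helper name forward_backward, and the run count are illustrative choices, and a custom function could be timed the same way by swapping the call inside the helper.

```python
import torch
from torch.utils import benchmark

# Small illustrative sizes so the sketch runs quickly
A = torch.randn(8, 32, 16, requires_grad=True)
B = torch.randn(8, 16, 32, requires_grad=True)

def forward_backward(A, B):
    # One full forward + backward pass through the built-in bmm
    A.grad = None
    B.grad = None
    torch.bmm(A, B).sum().backward()

timer = benchmark.Timer(
    stmt="forward_backward(A, B)",
    globals={"forward_backward": forward_backward, "A": A, "B": B},
)
measurement = timer.timeit(20)  # run the statement 20 times
print(f"Median time per run: {measurement.median * 1e6:.1f} microseconds")
```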

Summary

In this tutorial, you've learned how to create custom autograd functions in PyTorch by extending the torch.autograd.Function class. We covered:

  1. The basic structure of a custom autograd function
  2. How to implement forward and backward passes
  3. Simple examples like custom ReLU and ELU
  4. Numerically stable computations with custom autograd
  5. A real-world application with focal loss
  6. Performance optimizations using custom gradient calculations

Custom autograd functions give you the flexibility to define your own operations while still benefiting from PyTorch's automatic differentiation capabilities. They are particularly useful for:

  • Implementing operations not available in PyTorch
  • Ensuring numerical stability in complex computations
  • Optimizing performance for specific use cases
  • Implementing custom loss functions

Exercises

To reinforce your understanding, try these exercises:

  1. Implement a custom softmax function with numerical stability
  2. Create a custom layer normalization function
  3. Implement a custom attention mechanism
  4. Create a custom optimization operation that fuses multiple operations for better performance
