
TensorFlow Custom Gradients

Introduction

TensorFlow's automatic differentiation system is powerful, but sometimes you need more control over how gradients are computed. Custom gradients allow you to override TensorFlow's default gradient behavior, which can be useful for:

  • Implementing novel optimization techniques
  • Improving numerical stability
  • Creating specialized layers with unique backpropagation behavior
  • Fixing gradient issues in complex operations
  • Implementing techniques like straight-through estimators

In this tutorial, we'll explore how to define custom gradients in TensorFlow, providing you with the knowledge to extend TensorFlow's capabilities for your specific needs.

Understanding Gradients in TensorFlow

Before diving into custom gradients, let's quickly review what gradients are in TensorFlow.

During training, TensorFlow computes the gradients of the loss with respect to the model parameters by backpropagation. Operations executed inside a tf.GradientTape are recorded, and TensorFlow chains together the built-in gradient of each recorded operation. For most operations this happens automatically. However, there are situations where you might want to define your own gradient calculations.
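As a baseline, the sketch below shows the default behavior with no custom gradient. The tape records the operations and returns the derivative assembled from each operation's built-in gradient. The input values are arbitrary.

```python
import tensorflow as tf

# No custom gradient: TensorFlow uses the built-in gradients of square, mul and add
x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)  # constants are not watched automatically
    y = tf.square(x) + 3.0 * x

# y is a vector, so the tape differentiates sum(y): d/dx (x^2 + 3x) = 2x + 3
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())
```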

Custom Gradients with tf.custom_gradient

TensorFlow provides the tf.custom_gradient decorator to define functions with custom gradients. This decorator lets you specify both the forward pass computation and how gradients should flow backward.

Basic Syntax

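```python
import tensorflow as tf

@tf.custom_gradient
def my_custom_function(x):
    # Your forward computation (squaring is just a placeholder example)
    result = tf.square(x)

    def grad(upstream):
        # upstream is the gradient of the loss w.r.t. result;
        # return the gradient of the loss w.r.t. x (chain rule: upstream * 2x)
        return upstream * 2.0 * x

    return result, grad
```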

Let's break this down:

  1. You define a function with your forward computation
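  2. Within this function, you define a nested grad function that receives the upstream gradient (the gradient of the final result, such as a loss, with respect to this function's output) and returns the gradient with respect to each input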
  3. The outer function returns both the result and the gradient function
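The sketch below shows what upstream holds. It uses a weighted loss with a hypothetical square_with_custom_grad function, which is not part of the original tutorial. The weights are the gradient of the loss with respect to the function's output, so the tape passes them in as upstream, and grad applies the chain rule.

```python
import tensorflow as tf

@tf.custom_gradient
def square_with_custom_grad(x):
    y = tf.square(x)

    def grad(upstream):
        # upstream = dL/dy, one value per element of y
        tf.print("upstream:", upstream)
        # Chain rule: dL/dx = dL/dy * dy/dx = upstream * 2x
        return upstream * 2.0 * x

    return y, grad

x = tf.constant([1.0, 2.0, 3.0])
w = tf.constant([10.0, 20.0, 30.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    loss = tf.reduce_sum(w * square_with_custom_grad(x))

# dL/dy = w, so upstream is [10, 20, 30] and dL/dx = 2 * w * x
dl_dx = tape.gradient(loss, x)
print(dl_dx.numpy())
```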

Simple Example: Custom ReLU

Let's implement a custom ReLU activation function with a slight modification to its gradient:

```python
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

@tf.custom_gradient
def custom_relu(x):
    result = tf.maximum(x, 0.0)

    def grad(dy):
        # Standard ReLU gradient is 1 where x > 0, and 0 elsewhere
        # We'll modify it to have a small slope (0.1) for non-positive values
        return dy * tf.cast(x > 0, x.dtype) + 0.1 * dy * tf.cast(x <= 0, x.dtype)

    return result, grad

# Let's test our custom ReLU
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so it must be watched explicitly
    y = custom_relu(x)

# Get gradients
dy_dx = tape.gradient(y, x)

print("Input:", x.numpy())
print("Output:", y.numpy())
print("Gradients:", dy_dx.numpy())
```

Output:

Input: [-2. -1.  0.  1.  2.]
Output: [0. 0. 0. 1. 2.]
Gradients: [0.1 0.1 0.1 1. 1. ]

Notice that our custom gradient returns 0.1 for non-positive inputs (including x = 0), where the standard ReLU gradient would be 0. The forward output is unchanged.
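Only the backward pass changed. As a result, custom_relu still produces the same outputs as tf.nn.relu, while its gradients match those of tf.nn.leaky_relu with alpha=0.1. The sketch below checks both claims at a few nonzero inputs. The function is redefined so the block runs on its own, and it uses tf.where as an equivalent way to write the same gradient.

```python
import tensorflow as tf

@tf.custom_gradient
def custom_relu(x):
    result = tf.maximum(x, 0.0)

    def grad(dy):
        # Slope 1 for x > 0 and 0.1 for x <= 0 (same rule as above, via tf.where)
        return tf.where(x > 0, dy, 0.1 * dy)

    return result, grad

x = tf.constant([-2.0, -0.5, 0.5, 2.0])
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y_custom = custom_relu(x)
    y_relu = tf.nn.relu(x)
    y_leaky = tf.nn.leaky_relu(x, alpha=0.1)

# Forward values should equal tf.nn.relu; gradients should equal leaky ReLU's
print("forward matches relu:", bool(tf.reduce_all(tf.equal(y_custom, y_relu))))
g_custom = tape.gradient(y_custom, x)
g_leaky = tape.gradient(y_leaky, x)
print("custom grad:", g_custom.numpy())
print("leaky grad: ", g_leaky.numpy())
del tape  # release the persistent tape
```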

Visualizing our Custom ReLU

Let's visualize our custom ReLU function and its gradient:

```python
def plot_function_and_gradient(function, name):
    # Evaluate the function and its gradient on a grid of inputs
    x_range = np.linspace(-3, 3, 100).astype(np.float32)
    x_tensor = tf.constant(x_range)

    with tf.GradientTape() as tape:
        tape.watch(x_tensor)
        y_tensor = function(x_tensor)

    gradients = tape.gradient(y_tensor, x_tensor)

    plt.figure(figsize=(12, 5))
    plt.suptitle(name)

    # Left panel: forward values
    plt.subplot(1, 2, 1)
    plt.plot(x_range, y_tensor.numpy())
    plt.title('Function')
    plt.grid(True)

    # Right panel: gradient values
    plt.subplot(1, 2, 2)
    plt.plot(x_range, gradients.numpy())
    plt.title('Gradient')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# Compare standard ReLU and our custom ReLU
plot_function_and_gradient(tf.nn.relu, 'tf.nn.relu')
plot_function_and_gradient(custom_relu, 'custom_relu')
```

This code generates two figures. The function panels match, but the gradient panels differ for negative inputs: tf.nn.relu has gradient 0 there, while custom_relu has 0.1.

Using Custom Gradients for Numerical Stability

Custom gradients are often used to improve numerical stability. Let's implement a numerically stable log-sum-exp function:

```python
@tf.custom_gradient
def log_sum_exp(x):
    # Subtract the row maximum so every exponent is <= 0 (no overflow)
    max_x = tf.reduce_max(x, axis=-1, keepdims=True)
    y = max_x + tf.math.log(tf.reduce_sum(tf.exp(x - max_x), axis=-1, keepdims=True))

    def grad(dy):
        # dy has shape [..., 1]; broadcasting spreads it over every input element.
        # The derivative of log-sum-exp is softmax(x), computed with the same shift.
        y_exp = tf.exp(x - max_x)
        denominator = tf.reduce_sum(y_exp, axis=-1, keepdims=True)
        return dy * y_exp / denominator

    return y, grad

# Test the function
x = tf.constant([[1000., 1000.], [0., 10.]])
result = log_sum_exp(x)
print("Input:", x.numpy())
print("Output:", result.numpy())
```

Output (values rounded):

Input: [[1000. 1000.]
 [0. 10.]]
Output: [[1000.6931]
 [10.000045]]

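Subtracting the row maximum before exponentiating keeps every exponent at or below zero, so tf.exp cannot overflow. Computing log(sum(exp(x))) directly would return inf for the first row, because exp(1000) exceeds the float32 range. The results are 1000 + log 2 for the first row and 10 + log(1 + e^-10) for the second.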
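The variant sketch below follows the computation-reuse idea: it computes the shifted exponentials once in the forward pass and lets grad capture them. It also compares the result with the naive formula and with TensorFlow's built-in tf.math.reduce_logsumexp. Finally, it checks that the custom gradient equals tf.nn.softmax(x), the derivative of log-sum-exp.

```python
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def log_sum_exp(x):
    # Shift by the row maximum so every exponent is <= 0
    max_x = tf.reduce_max(x, axis=-1, keepdims=True)
    shifted_exp = tf.exp(x - max_x)
    denominator = tf.reduce_sum(shifted_exp, axis=-1, keepdims=True)
    y = max_x + tf.math.log(denominator)

    def grad(dy):
        # Reuse the forward-pass intermediates: the gradient is softmax(x)
        return dy * shifted_exp / denominator

    return y, grad

x = tf.constant([[1000., 1000.], [0., 10.]])
with tf.GradientTape() as tape:
    tape.watch(x)
    custom_out = log_sum_exp(x)
custom_grad = tape.gradient(custom_out, x)

# Naive formula: exp(1000) overflows float32, so the first row becomes inf
naive = tf.math.log(tf.reduce_sum(tf.exp(x), axis=-1, keepdims=True))
builtin = tf.math.reduce_logsumexp(x, axis=-1, keepdims=True)

max_err = np.max(np.abs(custom_grad.numpy() - tf.nn.softmax(x).numpy()))
print("naive:   ", naive.numpy().ravel())
print("custom:  ", custom_out.numpy().ravel())
print("built-in:", builtin.numpy().ravel())
print("max |grad - softmax|:", max_err)
```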

Gradient Clipping with Custom Gradients

Another common application is clipping the gradient that flows backward through a specific point in the network, while leaving the forward pass unchanged:

```python
@tf.custom_gradient
def clip_gradients(x, clip_value=1.0):
    # clip_value must be passed by keyword (see the call below);
    # tf.custom_gradient accepts keyword arguments only in eager mode.
    # Forward pass is the identity
    y = tf.identity(x)

    def grad(dy):
        # Clip the gradients during backpropagation
        return tf.clip_by_value(dy, -clip_value, clip_value)

    return y, grad

# Let's test gradient clipping
x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    # Compute a function that will have large gradients
    y = 10.0 * tf.square(clip_gradients(x, clip_value=0.5))

# Get gradients
gradients = tape.gradient(y, x)
print("Gradients after clipping:", gradients.numpy())
```

Output:

Gradients after clipping: [0.5 0.5 0.5]

Without clipping, the gradients would be 20 * x = [20, 40, 60]; after clipping, each one is capped at 0.5 as specified.
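tf.custom_gradient accepts keyword arguments such as clip_value only in eager mode, so the version above raises an error inside tf.function (which model.fit uses by default). One workaround is to capture the threshold in a closure so the decorated function takes only tensors. The sketch below does this with a hypothetical helper, make_gradient_clipper. Its inputs include a negative value to show clipping on both sides.

```python
import tensorflow as tf

def make_gradient_clipper(clip_value):
    # Hypothetical helper: clip_value is captured by the closure, so the
    # decorated function takes only a tensor and also works in tf.function
    @tf.custom_gradient
    def clip_gradients(x):
        def grad(dy):
            # Clip the upstream gradient to [-clip_value, clip_value]
            return tf.clip_by_value(dy, -clip_value, clip_value)
        # Forward pass is the identity
        return tf.identity(x), grad
    return clip_gradients

clip_half = make_gradient_clipper(0.5)

@tf.function
def clipped_grad(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = 10.0 * tf.square(clip_half(x))
    return tape.gradient(y, x)

# Unclipped gradients would be 20 * x = [-60, 0.2, 60]
grads = clipped_grad(tf.constant([-3.0, 0.01, 3.0]))
print(grads.numpy())
```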

Custom Gradients with Multiple Inputs

You can also define custom gradients for functions with multiple inputs:

```python
@tf.custom_gradient
def custom_multiply(x, y):
    result = x * y

    def grad(upstream):
        # For a function f(x,y) = x*y
        # df/dx = y and df/dy = x
        grad_x = upstream * y
        grad_y = upstream * x
        return grad_x, grad_y

    return result, grad

# Test with multiple inputs
x = tf.constant(2.0)
y = tf.constant(3.0)

with tf.GradientTape() as tape:
    tape.watch([x, y])
    z = custom_multiply(x, y)

# Get gradients
dx, dy = tape.gradient(z, [x, y])
print(f"z = {z.numpy()}")
print(f"dz/dx = {dx.numpy()}")
print(f"dz/dy = {dy.numpy()}")
```

Output:

z = 6.0
dz/dx = 3.0
dz/dy = 2.0
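grad must return one gradient per input, and each gradient should have the same shape as its input. When the forward pass broadcasts one input against another, the gradient for the broadcast input has to be summed back to its original shape. The sketch below shows this with a hypothetical scale_vector(v, s) that multiplies a vector by a scalar.

```python
import tensorflow as tf

@tf.custom_gradient
def scale_vector(v, s):
    # Hypothetical example: v has shape [n], s is a scalar broadcast over v
    result = v * s

    def grad(upstream):
        grad_v = upstream * s                 # shape [n], same as v
        grad_s = tf.reduce_sum(upstream * v)  # summed over the broadcast axis: scalar, like s
        return grad_v, grad_s

    return result, grad

v = tf.constant([1.0, 2.0, 3.0])
s = tf.constant(2.0)

with tf.GradientTape() as tape:
    tape.watch([v, s])
    z = tf.reduce_sum(scale_vector(v, s))

dv, ds = tape.gradient(z, [v, s])
print("dz/dv =", dv.numpy())  # s repeated for every element
print("dz/ds =", ds.numpy())  # sum of v
```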

Practical Application: Straight-Through Estimator

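The straight-through estimator is a technique for backpropagating through functions whose true gradient is zero almost everywhere, such as a binary step function. The forward pass keeps the hard step, while the backward pass substitutes a surrogate gradient. Let's implement it: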

```python
@tf.custom_gradient
def binary_activation(x):
    # Forward pass: binary step function
    y = tf.cast(x > 0, tf.float32)

    def grad(dy):
        # Clipped straight-through estimator: treat the step as the identity
        # (gradient 1) where -0.5 < x < 0.5, and block the gradient elsewhere
        grad_estimation = tf.where(
            tf.logical_and(x > -0.5, x < 0.5),
            tf.ones_like(x),
            tf.zeros_like(x)
        )
        return dy * grad_estimation

    return y, grad

# Create a simple model with binary activation
class BinaryLayer(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return binary_activation(self.dense(inputs))

# Test with a simple model
model = tf.keras.Sequential([
    BinaryLayer(10),
    tf.keras.layers.Dense(1)
])

# Create a small dataset
x = tf.random.normal((32, 5))
y = tf.random.normal((32, 1))

# Compile and train for one step (32 samples, default batch size 32) to verify gradients flow
model.compile(optimizer='adam', loss='mse')
model.fit(x, y, epochs=1, verbose=0)

print("Model trained successfully with binary activation and straight-through estimator")
```
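The sketch below checks the estimator directly by comparing it with the plain step function; it redefines binary_activation so it runs on its own. Without the custom gradient, the comparison op has no gradient and tape.gradient returns None. With it, a gradient of 1 passes through only for inputs between -0.5 and 0.5.

```python
import tensorflow as tf

@tf.custom_gradient
def binary_activation(x):
    # Forward pass: binary step function
    y = tf.cast(x > 0, x.dtype)

    def grad(dy):
        # Clipped straight-through estimator: gradient 1 where |x| < 0.5, else 0
        mask = tf.cast(tf.abs(x) < 0.5, dy.dtype)
        return dy * mask

    return y, grad

x = tf.constant([-1.0, -0.2, 0.3, 0.8])

# Plain step function: the comparison has no gradient, so the result is None
with tf.GradientTape() as tape:
    tape.watch(x)
    plain = tf.cast(x > 0, tf.float32)
plain_grad = tape.gradient(plain, x)

# Step function with the straight-through estimator
with tf.GradientTape() as tape:
    tape.watch(x)
    y = binary_activation(x)
grads = tape.gradient(y, x)

print("plain gradient:", plain_grad)     # None
print("outputs:       ", y.numpy())      # [0. 0. 1. 1.]
print("STE gradients: ", grads.numpy())  # [0. 1. 1. 0.]
```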

Advanced Example: Gradient Reversal Layer

Gradient reversal layers are used in domain adaptation tasks. They reverse the gradient during backpropagation while keeping the forward pass intact:

```python
@tf.custom_gradient
def gradient_reversal(x):
    y = tf.identity(x)

    def grad(dy):
        return -1.0 * dy  # Reverse the gradient

    return y, grad

class GradientReversalLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs):
        return gradient_reversal(inputs)

# Example usage in a domain adaptation scenario
feature_extractor = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu')
])

domain_classifier = tf.keras.Sequential([
    GradientReversalLayer(),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

# Main task classifier
task_classifier = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Usage in a model
inputs = tf.keras.layers.Input(shape=(100,))
features = feature_extractor(inputs)
task_output = task_classifier(features)
domain_output = domain_classifier(features)

model = tf.keras.Model(inputs=inputs, outputs=[task_output, domain_output])
```

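This model shares one feature extractor between a task classifier and a domain classifier. When it is trained on both outputs, the domain classifier learns to predict the domain. Meanwhile, the reversed gradient pushes the feature extractor toward features from which the domain cannot be predicted, that is, domain-invariant features.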
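In practice the reversed gradient is often multiplied by a scale factor (commonly written lambda), so the adversarial signal can be weighted against the main task. The minimal sketch below assumes a fixed scale and uses the hypothetical names make_gradient_reversal and ScaledGradientReversal.

```python
import tensorflow as tf

def make_gradient_reversal(scale=1.0):
    # Hypothetical helper; scale plays the role of the lambda factor
    @tf.custom_gradient
    def reverse(x):
        def grad(dy):
            # Identity forward, -scale * dy backward
            return -scale * dy
        return tf.identity(x), grad
    return reverse

class ScaledGradientReversal(tf.keras.layers.Layer):
    # Hypothetical layer: drop-in variant of GradientReversalLayer with a fixed scale
    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        self._reverse = make_gradient_reversal(scale)

    def call(self, inputs):
        return self._reverse(inputs)

layer = ScaledGradientReversal(scale=0.3)
x = tf.constant([[1.0, -2.0, 3.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    out = layer(x)
    loss = tf.reduce_sum(out)

grads = tape.gradient(loss, x)
print("forward:", out.numpy())     # unchanged input
print("gradient:", grads.numpy())  # -0.3 for every element
```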

Performance Considerations

When implementing custom gradients, keep these performance tips in mind:

  1. Computation Reuse: Store and reuse computations needed for both forward and backward passes
  2. Avoid Redundant Operations: Minimize duplicate calculations in the gradient function
  3. Use TensorFlow Ops: Leverage TensorFlow's optimized operations when possible
  4. Test Numerically: Verify your gradient implementations using tf.test.compute_gradient
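The minimal sketch below illustrates tip 4. tf.test.compute_gradient returns the theoretical Jacobian (from your grad function) and a numerical Jacobian (from finite differences), and the two should agree. The function scaled_square is a hypothetical example, and the float64 inputs keep the finite-difference error small.

```python
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def scaled_square(x):
    # Forward: 3 * x^2
    y = 3.0 * tf.square(x)

    def grad(upstream):
        # Analytic derivative: 6x, scaled by the upstream gradient
        return upstream * 6.0 * x

    return y, grad

x = tf.constant([0.5, -1.0, 2.0], dtype=tf.float64)

# Each result holds one Jacobian (a 2-D array) per input tensor
theoretical, numerical = tf.test.compute_gradient(scaled_square, [x])
max_err = np.max(np.abs(theoretical[0] - numerical[0]))
print("max |theoretical - numerical|:", max_err)
```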

Summary

Custom gradients in TensorFlow provide a powerful way to customize the behavior of your neural networks. In this tutorial, we covered:

  • The basics of the tf.custom_gradient decorator
  • Implementing simple custom gradient functions
  • Practical applications such as numerical stability, gradient clipping, and straight-through estimators
  • Advanced use cases like gradient reversal layers
  • Performance considerations

With custom gradients, you can implement unique behaviors that aren't possible with standard TensorFlow operations, opening up new possibilities for your machine learning models.

Exercises

  1. Implement a custom sigmoid activation function with scaled gradients
  2. Create a custom L1 regularization function with adjustable gradient behavior
  3. Implement a custom layer that uses the "Gumbel-softmax trick" with a straight-through estimator
  4. Create a custom optimizer using custom gradients
  5. Implement a "hard tanh" function with custom gradients for stable training

