
TensorFlow GradientTape

Introduction

TensorFlow's GradientTape is a powerful tool for automatic differentiation: it records the operations executed inside its context so that derivatives can be computed from that record later. It's particularly useful when you implement custom training procedures, advanced optimization techniques, or any algorithm that requires gradient computation.

In this tutorial, you'll learn:

  • What automatic differentiation is and why it's important
  • How to use GradientTape to compute gradients
  • Persistent vs. non-persistent tapes
  • Building custom training loops with GradientTape
  • Advanced techniques and best practices

What is Automatic Differentiation?

Automatic differentiation is the foundation of training neural networks. It calculates the gradients (derivatives) of functions defined in code, which are essential for optimization algorithms like gradient descent.

Unlike numerical differentiation (which approximates derivatives) or symbolic differentiation (which derives mathematical formulas), automatic differentiation tracks the operations actually executed and applies the chain rule to them, computing gradients that are exact up to floating-point rounding, and doing so efficiently.
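
To make the distinction concrete, here is a toy sketch in plain Python, not TensorFlow's actual implementation. A hypothetical Var class records each addition and multiplication together with its local derivative, and a hypothetical backward function walks that record in reverse, applying the chain rule. A central finite difference is included for comparison: its result only approximates the derivative and depends on the step size h.

```python
class Var:
    """A scalar that remembers how it was computed (toy illustration only)."""

    def __init__(self, value, parents=()):
        self.value = value
        # Each entry is (parent Var, local derivative d(self)/d(parent))
        self.parents = list(parents)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return Var(self.value * other.value, [(self, other.value), (other, self.value)])


def backward(output):
    """Reverse-mode pass: apply the chain rule from the output back to the inputs."""
    order, seen = [], set()

    def visit(node):  # depth-first topological sort of the recorded graph
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += node.grad * local


def central_difference(f, x, h=1e-5):
    """Numerical approximation of f'(x); the result depends on h."""
    return (f(x + h) - f(x - h)) / (2 * h)


x = Var(3.0)
y = x * x
backward(y)
print(x.grad)                                    # exact: 6.0
print(central_difference(lambda v: v * v, 3.0))  # only approximately 6.0
```

Reverse mode pays off when a single scalar output, such as a loss, depends on many inputs: one backward pass yields the derivative with respect to every input at once.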

Getting Started with GradientTape

Let's start with the basics. First, we need to import TensorFlow:

```python
import tensorflow as tf
```

The simplest use of GradientTape is to compute the derivative of a function:

```python
# Create a variable
x = tf.Variable(3.0)

# Record operations with GradientTape
with tf.GradientTape() as tape:
    y = x * x

# Calculate the gradient dy/dx
dy_dx = tape.gradient(y, x)
print(f"dy/dx at x = 3.0: {dy_dx.numpy()}")
```

Output:

dy/dx at x = 3.0: 6.0

What happened here? Because x is a tf.Variable, the GradientTape automatically recorded every operation involving x inside the with block. Calling tape.gradient(y, x) then computed the derivative of y = x * x with respect to x, which is 2x = 6.0 at x = 3.0.
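
The same call works when the source is a vector. Because y below is not a scalar, tape.gradient returns the gradient of the sum of its elements; for an element-wise square that is simply 2x for each entry. The values are chosen for illustration.

```python
import tensorflow as tf

x = tf.Variable([1.0, 2.0, 3.0])

with tf.GradientTape() as tape:
    y = x * x  # element-wise square, shape (3,)

# y is not a scalar, so the tape differentiates sum(y) = x1² + x2² + x3²
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # [2. 4. 6.]
```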

Computing Gradients with Respect to Multiple Variables

GradientTape can compute gradients with respect to multiple variables simultaneously:

```python
x = tf.Variable(2.0)
y = tf.Variable(3.0)

with tf.GradientTape() as tape:
    f = x * x * y + y * y

# Get gradients df/dx and df/dy
gradients = tape.gradient(f, [x, y])
print(f"df/dx: {gradients[0].numpy()}")
print(f"df/dy: {gradients[1].numpy()}")
```

Output:

df/dx: 12.0
df/dy: 10.0
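
You can check these numbers by hand: ∂f/∂x = 2xy = 12 and ∂f/∂y = x² + 2y = 10 at x = 2, y = 3. The sources argument also accepts nested structures such as a dict, and the result comes back in the same structure, which keeps gradients labeled when there are many variables:

```python
import tensorflow as tf

x = tf.Variable(2.0)
y = tf.Variable(3.0)

with tf.GradientTape() as tape:
    f = x * x * y + y * y

# The result mirrors the structure of `sources`: a dict in, a dict out
grads = tape.gradient(f, {"x": x, "y": y})
print(grads["x"].numpy())  # 2xy = 12.0
print(grads["y"].numpy())  # x² + 2y = 10.0
```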

Persistent vs. Non-Persistent Tapes

By default, a GradientTape is non-persistent: its resources are released as soon as tape.gradient() is called, so it can compute only one set of gradients:

```python
x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x * x
    z = x * x * x

# First gradient computation works
dy_dx = tape.gradient(y, x)
print(f"dy/dx: {dy_dx.numpy()}")

# A second call on the same non-persistent tape raises an error
try:
    dz_dx = tape.gradient(z, x)
except RuntimeError as e:
    print(f"Error: {e}")
```

Output:

dy/dx: 6.0
Error: A non-persistent GradientTape can only be used to compute one set of gradients
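
When what you need is the gradient of several targets added together, a non-persistent tape is enough: passing a list of targets to a single tape.gradient call returns the gradient of their sum. Separate per-target gradients still need either a persistent tape or one tape per target. A minimal sketch:

```python
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:  # non-persistent
    y = x * x
    z = x * x * x

# One call with a list of targets returns d(y + z)/dx = 2x + 3x²
total = tape.gradient([y, z], x)
print(total.numpy())  # 6.0 + 27.0 = 33.0
```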

To compute multiple gradients from the same operations, use a persistent tape:

```python
x = tf.Variable(3.0)

with tf.GradientTape(persistent=True) as tape:
    y = x * x
    z = x * x * x

dy_dx = tape.gradient(y, x)
dz_dx = tape.gradient(z, x)

print(f"dy/dx: {dy_dx.numpy()}")
print(f"dz/dx: {dz_dx.numpy()}")

# Delete the tape when done so its resources can be released
del tape
```

Output:

dy/dx: 6.0
dz/dx: 27.0

Watching Tensors

By default, GradientTape automatically watches only the trainable tf.Variable objects accessed inside its context. To track operations on regular tensors, such as the output of tf.constant, you need to "watch" them explicitly:

```python
x = tf.constant(3.0)  # Not a Variable

with tf.GradientTape() as tape:
    tape.watch(x)  # Explicitly watch the tensor
    y = x * x

dy_dx = tape.gradient(y, x)
print(f"dy/dx: {dy_dx.numpy()}")
```

Output:

dy/dx: 6.0
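
The converse also holds: if the constant is not watched, its gradient comes back as None. And if you want the opposite default, GradientTape(watch_accessed_variables=False) stops the tape from watching variables automatically, so only what you pass to tape.watch() is tracked. This is handy for restricting gradients to a subset of variables:

```python
import tensorflow as tf

# 1) A constant that is never watched gets no gradient
c = tf.constant(3.0)
with tf.GradientTape() as tape:
    y = c * c
grad_c = tape.gradient(y, c)
print(grad_c)  # None

# 2) Turn off automatic watching and choose the sources explicitly
a = tf.Variable(2.0)
b = tf.Variable(5.0)
with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(a)  # only `a` is tracked; `b` behaves like a constant
    out = a * b
grad_a, grad_b = tape.gradient(out, [a, b])
print(grad_a.numpy(), grad_b)  # 5.0 None
```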

Nested GradientTapes (Higher-order Derivatives)

You can nest GradientTape contexts to compute higher-order derivatives:

```python
x = tf.Variable(1.0)

with tf.GradientTape() as tape_2:
    with tf.GradientTape() as tape_1:
        y = x * x * x

    # First derivative: dy/dx = 3x². It is computed inside tape_2's
    # context so that tape_2 records it and can differentiate it again.
    dy_dx = tape_1.gradient(y, x)

# Second derivative: d²y/dx² = 6x
d2y_dx2 = tape_2.gradient(dy_dx, x)

print(f"dy/dx: {dy_dx.numpy()}")
print(f"d²y/dx²: {d2y_dx2.numpy()}")
```

Output:

dy/dx: 3.0
d²y/dx²: 6.0
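
The nested-tape pattern works for any differentiable function. The helper below, first_and_second_derivative (a name introduced here for illustration), applies it to tf.sin and compares the results with the known derivatives cos(x) and -sin(x):

```python
import math

import tensorflow as tf


def first_and_second_derivative(f, x_value):
    """Return f'(x) and f''(x) at x_value using two nested tapes."""
    x = tf.Variable(x_value)
    with tf.GradientTape() as outer:
        with tf.GradientTape() as inner:
            y = f(x)
        # Computed inside `outer`, so this gradient is itself recorded
        dy_dx = inner.gradient(y, x)
    d2y_dx2 = outer.gradient(dy_dx, x)
    return dy_dx, d2y_dx2


d1, d2 = first_and_second_derivative(tf.sin, 0.5)
print(d1.numpy(), math.cos(0.5))   # d/dx sin(x) = cos(x)
print(d2.numpy(), -math.sin(0.5))  # d²/dx² sin(x) = -sin(x)
```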

Building a Custom Training Loop

One of the most common use cases for GradientTape is building custom training loops. Here's how to implement a simple linear regression model:

```python
import numpy as np
import tensorflow as tf

# True parameters
true_w = 2.0
true_b = 1.0

# Generate synthetic data with some noise
x_data = np.random.randn(1000).astype(np.float32)
y_data = true_w * x_data + true_b + 0.1 * np.random.randn(1000).astype(np.float32)

# Create a TensorFlow dataset of mini-batches
dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
dataset = dataset.batch(32)

# Initialize model parameters
w = tf.Variable(0.0)
b = tf.Variable(0.0)

# Optimization parameters
learning_rate = 0.1
epochs = 100

# Mean squared error loss function
def loss_fn(y_pred, y_true):
    return tf.reduce_mean(tf.square(y_pred - y_true))

# Training loop
for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0

    for x_batch, y_batch in dataset:
        with tf.GradientTape() as tape:
            # Forward pass
            y_pred = w * x_batch + b
            # Compute loss
            batch_loss = loss_fn(y_pred, y_batch)

        # Get gradients of the loss with respect to w and b
        gradients = tape.gradient(batch_loss, [w, b])

        # Update parameters using gradient descent
        w.assign_sub(learning_rate * gradients[0])
        b.assign_sub(learning_rate * gradients[1])

        epoch_loss += float(batch_loss)
        num_batches += 1

    if epoch % 10 == 0:
        # Average over the actual number of batches (32, the last holding 8 examples)
        print(f"Epoch {epoch}, Loss: {epoch_loss / num_batches}")
        print(f"Parameters: w = {w.numpy()}, b = {b.numpy()}")

print("\nFinal parameters:")
print(f"w = {w.numpy()}, true_w = {true_w}")
print(f"b = {b.numpy()}, true_b = {true_b}")
```

Output (will vary due to randomness):

Epoch 0, Loss: 3.2021788359642028
Parameters: w = 0.47811308503150635, b = 0.26328483223915
Epoch 10, Loss: 0.11166522495925427
Parameters: w = 1.8563919067382812, b = 0.8553428053855896
...
Epoch 90, Loss: 0.01023756639789723
Parameters: w = 1.9912989139556885, b = 0.9902616143226624

Final parameters:
w = 1.9914206266403198, true_w = 2.0
b = 0.9904240965843201, true_b = 1.0
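
Since every batch repeats the same computation, the training step is a natural candidate for @tf.function. The sketch below compiles the step of the loop above; it assumes a seeded np.random.default_rng generator for repeatability and trains for 20 epochs instead of 100, which is enough for this problem. Because 1,000 = 31 × 32 + 8, the final batch has a different shape, so expect one extra trace on top of the first.

```python
import numpy as np
import tensorflow as tf

# Same synthetic setup as above, with a fixed seed for reproducibility
rng = np.random.default_rng(0)
x_data = rng.standard_normal(1000).astype(np.float32)
y_data = (2.0 * x_data + 1.0 + 0.1 * rng.standard_normal(1000)).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(32)

w = tf.Variable(0.0)
b = tf.Variable(0.0)
learning_rate = 0.1


@tf.function  # traced into a graph on the first call and for each new input shape
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(w * x_batch + b - y_batch))
    dw, db = tape.gradient(loss, [w, b])
    w.assign_sub(learning_rate * dw)
    b.assign_sub(learning_rate * db)
    return loss


for epoch in range(20):
    for x_batch, y_batch in dataset:
        train_step(x_batch, y_batch)

print(w.numpy(), b.numpy())  # should end up close to 2.0 and 1.0
```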

Controlling Gradient Recording

Sometimes you don't want gradients to flow through part of a computation. tf.stop_gradient() returns its input unchanged in the forward pass but blocks gradients in the backward pass, so the tape treats its result as a constant:

```python
x = tf.Variable(2.0)
y = tf.Variable(3.0)

with tf.GradientTape() as tape:
    # f will be treated as a constant for gradient computation
    f = tf.stop_gradient(x * y)
    g = x * f

# Request both gradients in one call: a non-persistent tape allows only one
dg_dx, dg_dy = tape.gradient(g, [x, y])

print(f"dg/dx: {dg_dx.numpy()}")  # Equal to f (= x*y = 6.0)
print(f"dg/dy: {dg_dy}")  # None: y only appears inside stop_gradient
```

Output:

dg/dx: 6.0
dg/dy: None
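
If downstream code expects a tensor rather than None (for example, to add gradients together), tape.gradient accepts unconnected_gradients=tf.UnconnectedGradients.ZERO, which returns zeros for sources that have no gradient path to the target:

```python
import tensorflow as tf

x = tf.Variable(2.0)
y = tf.Variable(3.0)

with tf.GradientTape() as tape:
    f = tf.stop_gradient(x * y)  # treated as a constant (value 6.0)
    g = x * f

# Ask for zeros instead of None for sources with no gradient path
dg_dx, dg_dy = tape.gradient(
    g, [x, y], unconnected_gradients=tf.UnconnectedGradients.ZERO
)
print(dg_dx.numpy(), dg_dy.numpy())  # 6.0 0.0
```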

Practical Example: Training a Neural Network

Now let's put everything together to train a simple neural network for classifying MNIST digits:

```python
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten

# Load and preprocess the data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Normalize to [0, 1] as float32, the dtype of the model's weights
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Convert labels to one-hot encoding
y_train = tf.one_hot(y_train, 10)
y_test = tf.one_hot(y_test, 10)

# Create training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(10000).batch(64)

# Create model (subclassing tf.keras.Model tracks the layers' variables)
class SimpleNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(10)  # Raw logits, no softmax

    def call(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)

# Initialize model and optimizer
model = SimpleNN()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Loss function (expects logits because the last layer has no softmax)
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# Single training step
@tf.function  # Optional: compile for better performance
def train_step(images, labels):
    with tf.GradientTape() as tape:
        # Forward pass
        predictions = model(images)
        loss = loss_fn(labels, predictions)

    # Get gradients
    gradients = tape.gradient(loss, model.trainable_variables)

    # Update weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss

# Train the model
epochs = 5

for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0

    for images, labels in train_dataset:
        batch_loss = train_step(images, labels)
        total_loss += float(batch_loss)
        num_batches += 1

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

# Evaluate model on the full test set
@tf.function
def test_accuracy():
    predictions = model(x_test)
    predicted_classes = tf.argmax(predictions, axis=1)
    true_classes = tf.argmax(y_test, axis=1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted_classes, true_classes), tf.float32))
    return accuracy

print(f"Test accuracy: {test_accuracy().numpy():.4f}")
```

Output (approximate):

Epoch 1/5, Loss: 0.3098
Epoch 2/5, Loss: 0.1432
Epoch 3/5, Loss: 0.1062
Epoch 4/5, Loss: 0.0859
Epoch 5/5, Loss: 0.0731
Test accuracy: 0.9742
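
To see what tape.gradient returns in a step like train_step, the sketch below runs one forward and backward pass on a random batch. It assumes the same layer sizes as SimpleNN but builds them with tf.keras.Sequential for brevity. There is exactly one gradient per trainable variable, with the same shape as that variable, which is why the two lists can be zipped together for apply_gradients:

```python
import tensorflow as tf

# Same layer sizes as SimpleNN, written as a Sequential model for brevity
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10),
])
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# A fake batch of 4 "images" and labels, just to run one forward/backward pass
images = tf.random.uniform((4, 28, 28))
labels = tf.one_hot([0, 1, 2, 3], depth=10)

with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(images))
gradients = tape.gradient(loss, model.trainable_variables)

# One gradient per variable, same shape: (784, 128), (128,), (128, 10), (10,)
for variable, grad in zip(model.trainable_variables, gradients):
    print(variable.shape, grad.shape)
```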

Best Practices and Tips

  1. Memory Management: Release persistent tapes with del tape when you're done to free memory.

  2. Speed Optimization: Consider using @tf.function to compile your gradient operations for better performance.

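  3. Debugging: If you get None gradients, check whether:

    • The tensors you differentiate with respect to are being watched (trainable variables are watched automatically; constants need tape.watch)
    • Your computation actually depends on those variables
    • Part of the computation leaves TensorFlow (for example, through NumPy or .numpy()), which breaks the recorded chain
    • You're trying to differentiate through non-differentiable operations or integer tensors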
  4. Gradient Clipping: Prevent exploding gradients by clipping:

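```python
# A variant of train_step from the MNIST example, with gradient clipping
@tf.function
def train_step_clipped(images, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(labels, model(images))
    gradients = tape.gradient(loss, model.trainable_variables)

    # Rescale all gradients together so their global L2 norm is at most 1.0
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, clip_norm=1.0)
    optimizer.apply_gradients(zip(clipped_gradients, model.trainable_variables))
    return loss
```
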
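  5. Avoid Retracing: A @tf.function is traced again whenever it is called with a new input signature: different tensor shapes or dtypes, or different values of Python arguments such as plain numbers. Each trace is slow, so pass values that change between calls as tensors and keep input shapes consistent.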
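
Two of these tips are easy to check directly. The sketch below reproduces two common causes of None gradients, leaving TensorFlow through .numpy() and watching an integer tensor, and then counts how often a @tf.function is traced. The counter is a plain Python list appended to during tracing, a debugging trick rather than an official API:

```python
import tensorflow as tf

# --- Debugging None gradients ---
x = tf.Variable(2.0)
with tf.GradientTape() as tape:
    # .numpy() leaves TensorFlow, so the tape cannot connect y to x
    y = tf.constant(x.numpy() ** 2)
grad_via_numpy = tape.gradient(y, x)
print(grad_via_numpy)  # None

n = tf.constant(3)  # int32: integer tensors do not carry gradients
with tf.GradientTape() as tape:
    tape.watch(n)  # TensorFlow logs a warning about the non-floating dtype
    m = n * n
grad_int = tape.gradient(m, n)
print(grad_int)  # None

# --- Observing retracing ---
traces = []


@tf.function
def square(v):
    traces.append(1)  # Python side effect: runs only while tracing
    return v * v


for value in [1.0, 2.0, 3.0]:
    square(value)  # each new Python float is a new signature -> new trace
python_traces = len(traces)

traces.clear()
for value in [1.0, 2.0, 3.0]:
    square(tf.constant(value))  # same dtype and shape -> traced once, then reused
tensor_traces = len(traces)

print(python_traces, tensor_traces)  # 3 1
```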

Summary

TensorFlow's GradientTape is a flexible and powerful tool for automatic differentiation. It allows you to:

  • Compute gradients for any differentiable operation
  • Build custom training loops and optimization algorithms
  • Calculate higher-order derivatives
  • Control which operations are tracked for gradient computation

Whether you're implementing advanced research algorithms or simply need more control over your training process, GradientTape provides the flexibility needed to define exactly how gradients are calculated and applied.

Exercises

  1. Implement a simple neural network for binary classification using GradientTape and compare your results with Keras' built-in training.

  2. Compute the Hessian matrix (second derivatives) of a function using nested GradientTape.

  3. Implement momentum optimization from scratch using GradientTape.

  4. Create a custom layer with learnable parameters and train it using GradientTape.

  5. Experiment with gradient checkpointing to reduce memory usage during the backward pass.
