TensorFlow @tf.function
Introduction
TensorFlow's @tf.function decorator is one of the most powerful features for improving performance in TensorFlow 2.x applications. It bridges the gap between eager execution (which is intuitive and Pythonic) and graph execution (which is more efficient and deployable).
In this tutorial, we'll explore what @tf.function is, why it's important, and how to use it effectively in your TensorFlow code. By the end, you'll understand how to leverage this feature to speed up your machine learning models without sacrificing readability or debuggability.
Understanding TensorFlow Execution Modes
Before diving into @tf.function, let's understand the two execution modes in TensorFlow:
Eager Execution
import tensorflow as tf
# Eager execution is enabled by default in TensorFlow 2.x
a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[5, 6], [7, 8]])
c = tf.matmul(a, b)
print(c)
Output:
tf.Tensor(
[[19 22]
[43 50]], shape=(2, 2), dtype=int32)
In eager execution:
- Operations are evaluated immediately
- Values can be inspected right away
- Standard Python debugging tools work
- It's more intuitive for beginners
Graph Execution
Graph execution is the original TensorFlow execution mode, in which computations are first defined as a graph and then executed. This provides several advantages:
- Better performance through whole-graph optimizations
- Easy deployment (graphs can be saved and run without the original Python code)
- Hardware-specific optimizations
- Automatic parallelism between independent operations
The @tf.function decorator lets us convert Python functions to TensorFlow graphs, getting the best of both worlds.
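To make the idea of a "graph" concrete, here is a minimal sketch (the function name is just illustrative) that traces a small function and lists the operations TensorFlow records; the exact op list may vary by version:
import tensorflow as tf
@tf.function
def matmul_fn(a, b):
    return tf.matmul(a, b)
# Trace the function for 2x2 int32 inputs and inspect the resulting graph.
concrete = matmul_fn.get_concrete_function(
    tf.TensorSpec([2, 2], tf.int32), tf.TensorSpec([2, 2], tf.int32))
print([op.type for op in concrete.graph.get_operations()])
# e.g. ['Placeholder', 'Placeholder', 'MatMul', 'Identity']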
Basic Usage of @tf.function
The simplest way to use @tf.function is to add it as a decorator to your function:
import tensorflow as tf
import time
# Without @tf.function - uses eager execution
def eager_function(x):
return tf.reduce_sum(tf.square(x))
# With @tf.function - converts to graph execution
@tf.function
def graph_function(x):
return tf.reduce_sum(tf.square(x))
# Create some data
data = tf.random.normal([1000, 1000])
# Compare performance
start = time.time()
for i in range(100):
eager_function(data)
print(f"Eager execution time: {time.time() - start:.4f}s")
start = time.time()
for i in range(100):
graph_function(data)
print(f"Graph execution time: {time.time() - start:.4f}s")
Typical Output:
Eager execution time: 0.2813s
Graph execution time: 0.0562s
The performance improvement can be significant, especially for more complex functions and larger datasets!
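Note that the first call to graph_function pays a one-time tracing cost. For a cleaner comparison you can warm both functions up first and time them with timeit, roughly like this:
import timeit
# Warm up once so tracing time is not counted in the measurement.
eager_function(data)
graph_function(data)
print("eager:", timeit.timeit(lambda: eager_function(data), number=100))
print("graph:", timeit.timeit(lambda: graph_function(data), number=100))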
How @tf.function Works
When you apply @tf.function to a Python function:
- The first time the function is called, TensorFlow:
  - Traces the function to create a graph
  - Optimizes the graph
  - Caches the compiled graph for future use
- For subsequent calls:
  - If inputs have the same shapes and dtypes, TensorFlow uses the cached graph
  - If inputs differ in shape or dtype, TensorFlow may re-trace the function and cache a new graph
Let's see this in action:
@tf.function
def traced_function(x):
print(f"Tracing for input: {x}")
return tf.reduce_sum(x)
# First call - traces the function
print("First call:")
print(traced_function(tf.constant([1, 2, 3])))
# Second call with same input shape - uses cached graph
print("\nSecond call with same shape:")
print(traced_function(tf.constant([4, 5, 6])))
# Call with different shape - re-traces the function
print("\nCall with different shape:")
print(traced_function(tf.constant([[1, 2], [3, 4]])))
Output:
First call:
Tracing for input: Tensor("x:0", shape=(3,), dtype=int32)
tf.Tensor(6, shape=(), dtype=int32)
Second call with same shape:
tf.Tensor(15, shape=(), dtype=int32)
Call with different shape:
Tracing for input: Tensor("x:0", shape=(2, 2), dtype=int32)
tf.Tensor(10, shape=(), dtype=int32)
Notice that "Tracing for input" only prints when TensorFlow needs to re-trace the function!
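One related gotcha: Python values (as opposed to tensors) become part of the cache key, so passing a plain Python number creates a new trace for every distinct value. A small sketch (the function name is just for illustration):
@tf.function
def scale(x, factor):
    print("Tracing with factor =", factor)
    return x * factor
data = tf.constant([1.0, 2.0])
scale(data, 2.0)               # traces for factor=2.0
scale(data, 3.0)               # traces again for factor=3.0
scale(data, 2.0)               # reuses the factor=2.0 graph
# Passing the factor as a tensor avoids per-value retracing:
scale(data, tf.constant(2.0))  # one more trace, reused for any float32 scalar tensor
scale(data, tf.constant(3.0))  # no retrace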
Advanced Features of @tf.function
Using Input Signatures
To avoid unnecessary retracing, you can specify input signatures:
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
def specific_function(x):
print("Tracing function!")
return tf.reduce_sum(x)
# First call - traces the function
print("First call:")
print(specific_function(tf.constant([1, 2, 3], dtype=tf.int32)))
# Second call - different values and even a different length, but it still
# matches the [None] signature, so the cached graph is reused
print("\nSecond call:")
print(specific_function(tf.constant([4, 5, 6, 7], dtype=tf.int32)))
# This would raise an error because float32 doesn't match the int32 signature:
# specific_function(tf.constant([1.0, 2.0], dtype=tf.float32))
Output:
First call:
Tracing function!
tf.Tensor(6, shape=(), dtype=int32)
Second call:
tf.Tensor(22, shape=(), dtype=int32)
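Input signatures are also what you typically pin down when exporting a traced function for deployment. Here is a minimal sketch assuming a tf.Module wrapper (the Adder class and the /tmp/adder path are purely illustrative):
class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
    def add_all(self, x):
        return tf.reduce_sum(x)
module = Adder()
# Save the traced graph; it can be loaded and run without the Python source.
tf.saved_model.save(module, "/tmp/adder")
restored = tf.saved_model.load("/tmp/adder")
print(restored.add_all(tf.constant([1, 2, 3])))  # tf.Tensor(6, shape=(), dtype=int32)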
Function Polymorphism
TensorFlow functions can handle different types of inputs by creating separate graphs:
@tf.function
def polymorphic_function(a, b):
print(f"Tracing with a: {a}, b: {b}")
return a + b
# Trace for int32
polymorphic_function(tf.constant(1), tf.constant(1))
# Trace for float32
polymorphic_function(tf.constant(1.0), tf.constant(2.0))
# No retrace (reuses the int32 graph)
print(polymorphic_function(tf.constant(3), tf.constant(4)))
Output:
Tracing with a: Tensor("a:0", shape=(), dtype=int32), b: Tensor("b:0", shape=(), dtype=int32)
Tracing with a: Tensor("a:0", shape=(), dtype=float32), b: Tensor("b:0", shape=(), dtype=float32)
tf.Tensor(7, shape=(), dtype=int32)
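Each trace is stored as a separate concrete function, and you can look one up explicitly if you need it. A small sketch reusing polymorphic_function from above:
# Retrieve the concrete function that was traced for int32 scalar inputs.
int_fn = polymorphic_function.get_concrete_function(tf.constant(1), tf.constant(1))
print(int_fn(tf.constant(5), tf.constant(6)))  # tf.Tensor(11, shape=(), dtype=int32)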
Python Operations Inside @tf.function
Not all Python operations can be converted to TensorFlow graph operations. Here are some important considerations:
What Works Well
- TensorFlow operations (tf.*)
- Control flow based on TensorFlow values (tf.cond, tf.while_loop)
- Basic Python operations on TensorFlow values
What Doesn't Work or Needs Special Handling
- Side effects like printing - Use tf.print instead of print for reliable printing:
@tf.function
def print_example(x):
# This may not print or only print during tracing
print("Regular print:", x)
# This will print during execution
tf.print("TF print:", x)
return x
print_example(tf.constant([1, 2, 3]))
- Python state modifications - Be careful with global variables:
counter = 0
@tf.function
def increment_counter(x):
global counter
counter += 1 # This happens only during tracing!
return x + counter
# First call
print(increment_counter(tf.constant(1)))
print(f"Counter: {counter}")
# Second call
print(increment_counter(tf.constant(1)))
print(f"Counter: {counter}")
Output:
tf.Tensor(2, shape=(), dtype=int32)
Counter: 1
tf.Tensor(2, shape=(), dtype=int32)
Counter: 1 # Notice counter wasn't incremented again!
- Python collections - Use TensorFlow collections instead:
@tf.function
def use_tf_collections(x):
# Instead of Python list
# bad: result = []
# Use TensorFlow list
result = tf.TensorArray(tf.int32, size=3)
for i in range(3):
result = result.write(i, x + i)
return result.stack()
print(use_tf_collections(tf.constant(5)))
Output:
tf.Tensor([5 6 7], shape=(3,), dtype=int32)
Real-world Example: Training Loop with @tf.function
Let's put everything together and create an efficient training loop with @tf.function:
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
import time
# Create a simple model
model = Sequential([
Dense(128, activation='relu', input_shape=(784,)),
Dense(64, activation='relu'),
Dense(10, activation='softmax')
])
# Define the loss and optimizer explicitly (a custom training loop
# shouldn't rely on model.compile)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
# Define training step without @tf.function
def train_step_eager(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
# Define training step with @tf.function
@tf.function
def train_step_graph(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
# Generate dummy data
batch_size = 64
train_images = tf.random.normal([batch_size, 784])
train_labels = tf.random.uniform([batch_size], maxval=10, dtype=tf.int64)
# Compare performance
# Warm-up
_ = train_step_eager(train_images, train_labels)
_ = train_step_graph(train_images, train_labels)
# Eager execution timing
start_time = time.time()
for _ in range(100):
loss = train_step_eager(train_images, train_labels)
eager_time = time.time() - start_time
# Graph execution timing
start_time = time.time()
for _ in range(100):
loss = train_step_graph(train_images, train_labels)
graph_time = time.time() - start_time
print(f"Eager execution time: {eager_time:.4f}s")
print(f"Graph execution time: {graph_time:.4f}s")
print(f"Speedup: {eager_time / graph_time:.2f}x")
Typical Output:
Eager execution time: 0.4873s
Graph execution time: 0.1214s
Speedup: 4.01x
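In a real project you would call the compiled step inside an epoch loop over a tf.data.Dataset, roughly like this (a sketch; the batch size and epoch count are arbitrary):
dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(32)
for epoch in range(2):
    for images, labels in dataset:
        # Note: a batch shape the function hasn't seen before triggers one retrace.
        loss = train_step_graph(images, labels)
    tf.print("epoch", epoch, "loss", loss)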
Best Practices for Using @tf.function
- Keep functions pure: Avoid side effects and non-TensorFlow operations.
- Use input signatures for functions with consistent input shapes.
- Be aware of tracing costs: Complex functions take longer to trace, and inputs that keep changing cause repeated retracing (see the reduce_retracing sketch at the end of this section).
- Identify the bottleneck: Only apply @tf.function where needed, usually to computationally intensive parts.
- Move data preprocessing outside @tf.function when possible.
- Debug in eager mode first, then apply @tf.function.
- Be careful with Python constructs inside @tf.function-decorated functions.
- Check concrete functions to understand what's happening:
@tf.function
def my_function(x):
return x * x
# See all concrete functions
print(my_function.pretty_printed_concrete_signatures())
# Call the function with different types
my_function(tf.constant(2))
my_function(tf.constant(2.0))
# Check again after more concrete functions are created
print(my_function.pretty_printed_concrete_signatures())
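If you cannot pin down a single input signature but retracing is still hurting performance, TensorFlow 2.9 and later offer a reduce_retracing argument (formerly experimental_relax_shapes) that relaxes shapes automatically. A small sketch:
@tf.function(reduce_retracing=True)
def flexible_sum(x):
    print("Tracing!")
    return tf.reduce_sum(x)
flexible_sum(tf.constant([1, 2, 3]))        # traces for shape (3,)
flexible_sum(tf.constant([1, 2, 3, 4]))     # rather than tracing once per shape,
flexible_sum(tf.constant([1, 2, 3, 4, 5]))  # TensorFlow falls back to a relaxed shape=(None,) trace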
Common Pitfalls and Solutions
1. Different behavior between eager and graph modes:
import time
# This looks fine in eager mode, but inside a graph the Python call
# time.time() runs only once, during tracing, so its value is frozen
@tf.function
def buggy_function():
    return tf.constant(time.time())  # time captured during tracing!
# Fix: use TensorFlow ops, which are executed on every call
@tf.function
def fixed_function():
    return tf.timestamp()
2. Python side effects not captured:
values = []
@tf.function
def append_to_list(x):
values.append(x) # This happens only during tracing
return x
# Fix: Use TensorFlow variables for state
state = tf.Variable(0)
@tf.function
def update_state(x):
state.assign_add(x) # This works properly in the graph
return state
3. Iterations over Python collections:
@tf.function
def bad_iteration():
total = tf.constant(0)
for i in range(10): # This unrolls during tracing
total += i
return total
# Fix: Use tf.range instead
@tf.function
def good_iteration():
total = tf.constant(0)
for i in tf.range(10): # This becomes a proper graph op
total += i
return total
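If you're ever unsure what AutoGraph did with your Python control flow, you can inspect the code it generated (a debugging aid; the output is verbose and omitted here):
# Show the Python source that AutoGraph generates for the decorated function.
print(tf.autograph.to_code(good_iteration.python_function))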
When Not to Use @tf.function
While @tf.function can provide significant performance benefits, it's not always necessary or beneficial:
- During development and debugging - work in eager mode first (you can even force it globally, as shown below)
- For simple, non-repeated operations - the tracing overhead may exceed the gains
- When operations rely heavily on Python features - the conversion may be complex
- For operations with constantly changing shapes - this may cause excessive retracing
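For the debugging case in particular, you can temporarily disable graph execution globally without removing any decorators:
# Force every tf.function to run eagerly so breakpoints and regular
# print() calls behave normally again.
tf.config.run_functions_eagerly(True)
# ... step through and debug your code here ...
# Re-enable graph execution when you're done.
tf.config.run_functions_eagerly(False)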
Summary
The @tf.function decorator is a powerful tool for improving the performance of TensorFlow code by converting Python functions into optimized graphs. It bridges eager execution (which is easy to use and debug) with graph execution (which is more efficient and deployable).
Key takeaways:
- @tf.function traces Python functions and converts them to TensorFlow graphs
- Functions are traced the first time they're called or when input signatures change
- Graph execution can be significantly faster than eager execution
- Be aware of Python operations that don't translate well to graphs
- Use input_signature to control retracing behavior
- Debug in eager mode before applying @tf.function
Additional Resources
- Official TensorFlow Guide on tf.function
- AutoGraph reference documentation
- TensorFlow Performance Guide
Exercises
- Create a function that computes the Fibonacci sequence using @tf.function and compare its performance with an eager version.
- Implement a custom training loop for a simple neural network that classifies MNIST digits, using @tf.function.
- Experiment with different input signatures and observe how they affect tracing behavior.
- Debug a function with Python operations that don't work well inside @tf.function and fix it.
- Profile your own TensorFlow code to identify where @tf.function can provide the biggest performance gains.