
TensorFlow Common Patterns

Introduction

As you dive deeper into building machine learning models with TensorFlow, you'll notice certain patterns appearing repeatedly in well-written code. These patterns represent proven approaches to solving common problems efficiently. Understanding them will help you write more readable, maintainable, and efficient TensorFlow code.

This guide explores the most common patterns used in TensorFlow applications, from basic model building to advanced techniques. Whether you're just starting with TensorFlow or looking to refine your skills, these patterns will serve as valuable tools in your machine learning journey.

1. Model Creation Patterns

Model Subclassing Pattern

Subclassing tf.keras.Model lets you write the forward pass yourself in the call method, which gives the most flexibility for complex or dynamic architectures.

```python
import tensorflow as tf

class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=False):
        # Forward pass: inputs -> hidden layer -> class probabilities
        x = self.dense1(inputs)
        return self.dense2(x)

# Using the model: the weights are created on the first call
model = SimpleModel()
outputs = model(tf.zeros((1, 784)))  # shape (1, 10)
```
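The training argument of call matters once the model contains layers that behave differently during training and inference, such as Dropout. The following is a minimal sketch, assuming a hypothetical DropoutModel with the same 784-feature input used elsewhere in this guide; it shows that dropout is skipped when training=False, so repeated inference calls agree.

```python
import tensorflow as tf

class DropoutModel(tf.keras.Model):
    """Hypothetical example model with a dropout layer between two dense layers."""

    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        # Dropout randomly zeroes units only when training=True;
        # with training=False it passes values through unchanged
        x = self.dropout(x, training=training)
        return self.dense2(x)

model = DropoutModel()
batch = tf.random.normal((4, 784))  # placeholder input batch

# In inference mode, two calls on the same batch give identical outputs
out_a = model(batch, training=False)
out_b = model(batch, training=False)
print(out_a.shape)  # (4, 10)
```

Passing training through explicitly keeps the behavior correct both in model.fit, which sets it for you, and in custom training loops, where you set it yourself.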

Functional API Pattern

The functional API is great for models with multiple inputs, outputs, or shared layers.

```python
import tensorflow as tf

# Define inputs
inputs = tf.keras.Input(shape=(784,))

# Define the network
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

# Create the model
model = tf.keras.Model(inputs=inputs, outputs=outputs)
```
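To illustrate the multiple-input, multiple-output, and shared-layer cases, here is a sketch of a hypothetical model that encodes two 784-feature inputs with the same Dense layer instance (so both branches use one set of weights) and produces two outputs. The input and output names are illustrative only.

```python
import tensorflow as tf

# Two inputs with the same feature size
input_a = tf.keras.Input(shape=(784,), name='input_a')
input_b = tf.keras.Input(shape=(784,), name='input_b')

# Applying one layer instance to both inputs shares its weights
shared_dense = tf.keras.layers.Dense(64, activation='relu')
encoded_a = shared_dense(input_a)
encoded_b = shared_dense(input_b)

# Merge the two branches, then produce two separate outputs
merged = tf.keras.layers.Concatenate()([encoded_a, encoded_b])
class_output = tf.keras.layers.Dense(10, activation='softmax', name='class_output')(merged)
score_output = tf.keras.layers.Dense(1, name='score_output')(merged)

model = tf.keras.Model(inputs=[input_a, input_b], outputs=[class_output, score_output])

# Calling the model on a pair of placeholder batches returns both outputs
a = tf.random.normal((8, 784))
b = tf.random.normal((8, 784))
class_pred, score_pred = model([a, b])
```

Because shared_dense is created once, the model has only three weight-carrying layers (six trainable tensors), even though the dense encoding is applied twice.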

Sequential Model Pattern

For simple linear stacks of layers, the Sequential API provides a clean approach.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),  # declares the input shape up front
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
```
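Once defined, a Sequential model (like any Keras model) is typically compiled and trained with fit. The following is a minimal sketch that uses randomly generated placeholder data in place of a real dataset; the sizes and epoch count are arbitrary.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# The last layer applies softmax, so the loss receives probabilities
# (the string loss below uses from_logits=False)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Placeholder data: 256 random samples with integer labels in [0, 10)
x_train = np.random.rand(256, 784).astype('float32')
y_train = np.random.randint(0, 10, size=(256,))

# history.history maps each metric name to one value per epoch
history = model.fit(x_train, y_train, batch_size=32, epochs=2, verbose=0)
```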

2. Input Pipeline Patterns

tf.data Pattern

Efficiently loading and preprocessing data with tf.data:

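```python
import numpy as np
import tensorflow as tf

# Placeholder data for illustration: 1000 samples with 784 features each
features = np.random.rand(1000, 784).astype('float32')
labels = np.random.randint(0, 10, size=(1000,))

# Create a dataset from tensors
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Apply transformations
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(32)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

# Alternatively, use method chaining
dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# Iterate over the dataset
for x, y in dataset:
    # Training step here: x has shape (batch, 784), y has shape (batch,)
    pass
```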

Output: Each iteration provides a batch of features and labels ready for training.
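Unless the number of examples is a multiple of the batch size, the final batch produced by batch is smaller than the others. The small sketch below uses 100 placeholder examples to make this visible and shows how drop_remainder=True discards the partial batch, which can help when a model or accelerator needs a fixed batch dimension.

```python
import tensorflow as tf

# 100 placeholder examples with 3 features each
features = tf.random.normal((100, 3))
labels = tf.range(100)

# 100 = 3 * 32 + 4, so the last batch holds only 4 examples
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)
batch_sizes = [int(x.shape[0]) for x, _ in dataset]
print(batch_sizes)  # [32, 32, 32, 4]

# drop_remainder=True drops the incomplete final batch
fixed = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32, drop_remainder=True)
fixed_sizes = [int(x.shape[0]) for x, _ in fixed]
print(fixed_sizes)  # [32, 32, 32]
```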

Data Augmentation Pattern

Applying random data augmentation to the training set helps reduce overfitting:

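```python
import tensorflow as tf

def augment(image, label):
    # Flip the image horizontally with 50% probability
    image = tf.image.random_flip_left_right(image)
    # Randomly shift brightness by up to 0.2 in either direction
    image = tf.image.random_brightness(image, max_delta=0.2)
    return image, label

# Apply augmentation to the training dataset only (not to validation or test data);
# train_dataset is assumed to yield (image, label) pairs
train_dataset = train_dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
```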
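As an alternative to tf.image functions inside map, recent TF 2.x releases also provide Keras preprocessing layers such as tf.keras.layers.RandomFlip and tf.keras.layers.RandomRotation, which augment only when called with training=True. The following sketch assumes a batch of placeholder 32x32 RGB images; the rotation factor is an arbitrary choice.

```python
import tensorflow as tf

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    # factor is a fraction of a full turn: 0.1 allows up to 36 degrees either way
    tf.keras.layers.RandomRotation(0.1),
])

images = tf.random.uniform((8, 32, 32, 3))  # placeholder image batch

# Augmentation is applied only in training mode
augmented = data_augmentation(images, training=True)

# In inference mode these layers pass the images through unchanged
unchanged = data_augmentation(images, training=False)
```

Because the layers are no-ops at inference time, they can also be placed at the start of a model so that augmentation runs on the accelerator during training and is skipped automatically during evaluation.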

3. Custom Training Loop Patterns

GradientTape Training Pattern

Using tf.GradientTape for custom training:

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10)  # no softmax: the model outputs logits
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
    # Record the forward pass so gradients can be computed
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(labels, predictions)

    # Get gradients and update weights
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss

# Training loop (dataset is the tf.data pipeline built earlier)
for epoch in range(3):
    for x, y in dataset:
        loss = train_step(x, y)
    # Reports the loss of the last batch in the epoch
    print(f"Epoch {epoch+1}, Loss: {loss.numpy():.4f}")
```

Example output (exact values vary between runs):

Epoch 1, Loss: 0.3241
Epoch 2, Loss: 0.2187
Epoch 3, Loss: 0.1923
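To see what tape.gradient computes, it helps to look at a case small enough to check by hand. For y = x * x the derivative is 2x, so at x = 3 the gradient should be 6. The final step is a plain gradient-descent update written out manually with an arbitrary learning rate of 0.1; it illustrates the idea behind optimizer.apply_gradients rather than what Adam does internally.

```python
import tensorflow as tf

x = tf.Variable(3.0)

# Operations on trainable variables inside the tape's context are recorded
with tf.GradientTape() as tape:
    y = x * x  # y = x squared

# dy/dx = 2x, which is 6 at x = 3
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # 6.0

# One plain gradient-descent step: x <- x - learning_rate * gradient
learning_rate = 0.1
x.assign_sub(learning_rate * dy_dx)  # x is now approximately 2.4
```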

Custom Metrics Tracking Pattern

Track metrics during training:

```python
import tensorflow as tf

# Initialize metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

# Assumes model, optimizer, loss_fn, and train_dataset are defined as in
# the previous examples
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Update metrics
    train_loss.update_state(loss)
    train_accuracy.update_state(labels, predictions)

    return loss

epochs = 5

for epoch in range(epochs):
    # Reset metrics at the start of each epoch
    train_loss.reset_state()
    train_accuracy.reset_state()

    for images, labels in train_dataset:
        train_step(images, labels)

    print(
        f'Epoch {epoch + 1}, '
        f'Loss: {train_loss.result():.4f}, '
        f'Accuracy: {train_accuracy.result() * 100:.2f}%'
    )
```
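Keras metrics are stateful: update_state accumulates values, result reports the running value, and reset_state clears the accumulated state. The small sketch below uses hand-checkable numbers to show all three steps.

```python
import tensorflow as tf

mean_loss = tf.keras.metrics.Mean()
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

# Accumulate three loss values; the running mean is (1 + 2 + 3) / 3 = 2
for value in [1.0, 2.0, 3.0]:
    mean_loss.update_state(value)
mean_before_reset = float(mean_loss.result())
print(mean_before_reset)  # 2.0

# The argmax of the first two rows matches the label; the third row does not,
# so the accuracy is 2/3
labels = tf.constant([0, 1, 2])
predictions = tf.constant([[0.9, 0.1, 0.0],
                           [0.2, 0.7, 0.1],
                           [0.6, 0.3, 0.1]])
accuracy.update_state(labels, predictions)
accuracy_value = float(accuracy.result())

# Clear the accumulated state, as at the start of a new epoch
mean_loss.reset_state()
accuracy.reset_state()
mean_after_reset = float(mean_loss.result())
print(mean_after_reset)  # 0.0
```

Forgetting the reset makes each epoch's reported value a running average over all epochs so far, which hides how the model is changing.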

4. Model Deployment Patterns

SavedModel Pattern

Save and load models for deployment:

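```python
import tensorflow as tf

# Save in the native Keras format (Keras 3 / TF 2.16+ require the ".keras"
# extension; older TF 2.x releases also accept a plain directory name,
# which writes a SavedModel)
model.save('my_model.keras')

# Load the model
loaded_model = tf.keras.models.load_model('my_model.keras')

# Make predictions (new_data must match the model's input shape)
predictions = loaded_model.predict(new_data)

# Export a TensorFlow SavedModel directory for serving or TFLite conversion
# (Model.export is available in recent TF releases)
model.export('my_model')
```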
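A quick way to confirm that saving and loading preserved the model is to compare predictions before and after a round trip. The following sketch assumes a small placeholder model, placeholder data, a temporary directory, and a recent TF release that supports the .keras format.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Placeholder model and inputs for illustration
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])
new_data = np.random.rand(5, 4).astype('float32')

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'roundtrip_model.keras')
    model.save(path)
    loaded_model = tf.keras.models.load_model(path)

    original_preds = model.predict(new_data, verbose=0)
    loaded_preds = loaded_model.predict(new_data, verbose=0)

# The restored model should reproduce the original predictions
print(np.allclose(original_preds, loaded_preds))  # True
```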

TensorFlow Lite Conversion Pattern

Convert models for mobile and edge devices:

```python
import numpy as np
import tensorflow as tf

# Convert the SavedModel directory to TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_saved_model('my_model')
tflite_model = converter.convert()

# Save the TF Lite model
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

# Load the TF Lite model and allocate memory for its tensors
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

# Get input and output tensor details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test with a placeholder input matching the expected shape and dtype
test_image = np.random.rand(*input_details[0]['shape']).astype(input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], test_image)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
```
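The converter can also shrink a model with post-training quantization. Setting converter.optimizations = [tf.lite.Optimize.DEFAULT] enables the default optimizations, which quantize weights to 8 bits (dynamic-range quantization) when no representative dataset is supplied. The sketch below assumes a hypothetical DenseBlock module, a single 784-to-256 linear layer, exported as a SavedModel to a temporary directory, and compares the sizes of the two converted models.

```python
import os
import tempfile

import tensorflow as tf

class DenseBlock(tf.Module):
    """Hypothetical placeholder model: a single 784 -> 256 linear layer."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.random.normal((784, 256)))
        self.b = tf.Variable(tf.zeros((256,)))

    # A fixed input signature lets the SavedModel export a serving signature
    @tf.function(input_signature=[tf.TensorSpec([1, 784], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b

with tempfile.TemporaryDirectory() as tmp_dir:
    export_dir = os.path.join(tmp_dir, 'dense_block')
    tf.saved_model.save(DenseBlock(), export_dir)

    # Baseline float32 conversion
    float_model = tf.lite.TFLiteConverter.from_saved_model(export_dir).convert()

    # Post-training dynamic-range quantization of the weights
    converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    quant_model = converter.convert()

# Sizes in bytes of the two flatbuffers
print(len(float_model), len(quant_model))
```

Quantization trades a small amount of accuracy for size and speed, so it is worth re-checking accuracy on a validation set after converting.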

5. Debugging and Optimization Patterns

Eager Execution Debugging Pattern

Debug your model by inspecting tensors immediately:

```python
import tensorflow as tf

# Leave out the @tf.function decorator while debugging so every line runs eagerly.
# Assumes model, loss_fn, and optimizer are defined as in the earlier examples.
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)

    # Immediate feedback: the loss is already reduced to a scalar
    print("Loss shape:", loss.shape)
    print("Loss value:", loss.numpy())

    gradients = tape.gradient(loss, model.trainable_variables)
    # First 5 values in the first row of the first layer's kernel gradient
    print("Gradient of first layer:", gradients[0][0][:5].numpy())

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```
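Removing @tf.function is not the only option. tf.config.run_functions_eagerly(True) makes decorated functions run eagerly without editing them. Inside a compiled function, tf.print runs on every call, whereas a Python print or other Python side effect runs only while the function is being traced. The sketch below makes the difference visible with a hypothetical trace_log list.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper that records Python-side executions

@tf.function
def square(x):
    trace_log.append('traced')  # Python side effect: runs only during tracing
    tf.print("x =", x)          # graph op: runs on every call
    return x * x

square(tf.constant(2.0))
square(tf.constant(3.0))        # same dtype and shape, so no retrace
print(len(trace_log))           # 1

# Run tf.function-decorated code eagerly, e.g. while debugging
tf.config.run_functions_eagerly(True)
square(tf.constant(4.0))
square(tf.constant(5.0))
tf.config.run_functions_eagerly(False)
print(len(trace_log))           # 3
```

Remember to switch eager mode back off, since running everything eagerly gives up the speed benefits of graph execution.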

TensorBoard Integration Pattern

Visualize training progress with TensorBoard:

```python
import datetime

import tensorflow as tf

# Set up a timestamped TensorBoard log directory
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir, histogram_freq=1)  # weight histograms every epoch

# With model.fit
model.fit(
    train_dataset,
    epochs=5,
    validation_data=val_dataset,
    callbacks=[tensorboard_callback]
)

# With a custom training loop (train_step and train_accuracy as defined
# in the metrics example above)
summary_writer = tf.summary.create_file_writer(log_dir)
global_step = 0

for epoch in range(5):
    for images, labels in train_dataset:
        loss = train_step(images, labels)

        # Write to TensorBoard
        with summary_writer.as_default():
            tf.summary.scalar('loss', loss, step=global_step)
            tf.summary.scalar('accuracy', train_accuracy.result(), step=global_step)
        global_step += 1
```
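Summary writers record events to files under the log directory, which TensorBoard then reads (for example with tensorboard --logdir logs/fit). The small sketch below logs a few placeholder loss values to a temporary directory with a manually incremented step counter, then checks that an event file was created.

```python
import os
import tempfile

import tensorflow as tf

log_dir = tempfile.mkdtemp()
summary_writer = tf.summary.create_file_writer(log_dir)

global_step = 0
for epoch in range(2):
    for batch_loss in [0.9, 0.7, 0.5]:  # placeholder loss values
        with summary_writer.as_default():
            tf.summary.scalar('loss', batch_loss, step=global_step)
        global_step += 1

# Make sure buffered events are written to disk
summary_writer.flush()
event_files = [f for f in os.listdir(log_dir) if f.startswith('events.out.tfevents')]
print(event_files)
```

A single global step counter keeps points from different epochs in order on TensorBoard's x-axis.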

6. Real-World Application Example

Let's put these patterns together in a complete image classification example:

```python
import datetime

import tensorflow as tf
import tensorflow_datasets as tfds

# Load dataset (the CIFAR-10 test split is used for validation here)
(train_ds, val_ds), ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    as_supervised=True,
    with_info=True
)

# Data preprocessing pattern
def preprocess(image, label):
    # Normalize pixel values to [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

# Data augmentation pattern
def augment(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, 0.1)
    image = tf.image.random_contrast(image, 0.8, 1.2)
    # Keep pixel values in [0, 1] after the random adjustments
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

# Input pipeline pattern
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

train_ds = (
    train_ds
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .shuffle(1000)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

val_ds = (
    val_ds
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

# Model creation pattern - using Functional API
inputs = tf.keras.Input(shape=(32, 32, 3))

x = tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu')(inputs)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu')(x)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu')(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)
outputs = tf.keras.layers.Dense(10)(x)  # raw logits for the 10 classes

model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Training pattern
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Metrics tracking pattern
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
val_loss = tf.keras.metrics.Mean(name='val_loss')
val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')

# Custom training loop pattern
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_accuracy.update_state(labels, predictions)
    return loss

@tf.function
def val_step(images, labels):
    # training=False disables dropout during evaluation
    predictions = model(images, training=False)
    v_loss = loss_fn(labels, predictions)

    val_loss.update_state(v_loss)
    val_accuracy.update_state(labels, predictions)

# TensorBoard pattern
log_dir = "logs/cifar10/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = tf.summary.create_file_writer(log_dir)

# Training loop
epochs = 10

for epoch in range(epochs):
    # Reset metrics
    train_loss.reset_state()
    train_accuracy.reset_state()
    val_loss.reset_state()
    val_accuracy.reset_state()

    # Training
    for images, labels in train_ds:
        train_step(images, labels)

    # Validation
    for images, labels in val_ds:
        val_step(images, labels)

    # TensorBoard updates
    with summary_writer.as_default():
        tf.summary.scalar('train_loss', train_loss.result(), step=epoch)
        tf.summary.scalar('train_accuracy', train_accuracy.result(), step=epoch)
        tf.summary.scalar('val_loss', val_loss.result(), step=epoch)
        tf.summary.scalar('val_accuracy', val_accuracy.result(), step=epoch)

    # Print metrics
    print(
        f'Epoch {epoch+1}, '
        f'Loss: {train_loss.result():.3f}, '
        f'Accuracy: {train_accuracy.result()*100:.2f}%, '
        f'Val Loss: {val_loss.result():.3f}, '
        f'Val Accuracy: {val_accuracy.result()*100:.2f}%'
    )

# Saving pattern: native Keras format (model.export('cifar10_model')
# writes a TensorFlow SavedModel directory instead)
model.save('cifar10_model.keras')
```

This example demonstrates how to integrate multiple TensorFlow patterns into a cohesive application.
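Because the final Dense(10) layer outputs logits (hence from_logits=True in the loss), the trained model's predictions are unnormalized scores rather than probabilities. One common option is to wrap the model with a Softmax layer for inference. The sketch below uses a small stand-in model and placeholder images in place of the trained CIFAR-10 network.

```python
import tensorflow as tf

# Stand-in for the trained model: any Keras model that outputs 10 logits
logits_model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

# Append a Softmax layer so the outputs are class probabilities
probability_model = tf.keras.Sequential([logits_model, tf.keras.layers.Softmax()])

images = tf.random.uniform((4, 32, 32, 3))  # placeholder images
probabilities = probability_model(images)
predicted_classes = tf.argmax(probabilities, axis=1)
```

Keeping the training model on logits and adding softmax only for inference avoids the numerical issues of computing cross-entropy on already-normalized probabilities.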

Summary

In this guide, we've explored common TensorFlow patterns that will help you write more efficient, maintainable, and effective code:

  1. Model Creation Patterns - Different ways to create models (Sequential, Functional, Subclassing)
  2. Input Pipeline Patterns - Efficient data loading and preprocessing with tf.data
  3. Custom Training Loop Patterns - Fine-grained control over the training process
  4. Model Deployment Patterns - Saving and converting models for deployment
  5. Debugging and Optimization Patterns - Tools for monitoring and improving performance
  6. Real-World Application - How to combine these patterns in a complete solution

By recognizing and applying these patterns in your TensorFlow projects, you'll build better models faster and with fewer errors.


Exercises

  1. Basic Pattern Practice: Convert a Sequential model to use the Functional API.
  2. Data Pipeline: Create a tf.data pipeline that loads images from disk, applies data augmentation, and batches them efficiently.
  3. Custom Training: Implement a custom training loop with learning rate scheduling.
  4. Model Deployment: Convert a trained model to TensorFlow Lite and test its inference speed.
  5. Advanced Challenge: Create a multi-input model using the Functional API that processes both image and text data simultaneously.

