
TensorFlow Model Subclassing

In this comprehensive guide, you'll learn how to create custom neural network architectures using TensorFlow's Model Subclassing API. This flexible approach enables you to build complex model architectures with full Python control flow.

Introduction to Model Subclassing

TensorFlow provides multiple APIs for creating neural network models:

  1. Sequential API: Simple, linear stack of layers
  2. Functional API: More flexible, allows non-sequential models
  3. Model Subclassing API: The most flexible approach, offering complete control

Model Subclassing is particularly useful when you need:

  • Complex model architectures
  • Custom training logic
  • Conditional layer execution
  • Dynamic computational graphs

By subclassing the tf.keras.Model class, you can define your own forward pass logic and build truly custom neural networks.

Getting Started with Model Subclassing

Basic Structure

To create a custom model, you need to:

  1. Create a class that inherits from tf.keras.Model
  2. Initialize your layers in the __init__ method
  3. Define the forward pass in the call method

Here's a simple example:

python
import tensorflow as tf

class SimpleModel(tf.keras.Model):
    def __init__(self):
        super(SimpleModel, self).__init__()
        # Define your layers here
        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=None):
        # Define your forward pass here
        x = self.dense1(inputs)
        return self.dense2(x)

Creating and Using the Model

Let's create an instance of our model and use it:

python
# Create model instance
model = SimpleModel()

# Compile the model
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Build the model with an input shape (needed here so summary() can report parameter counts)
model.build(input_shape=(None, 784))

# Display model summary
model.summary()

Output:

Model: "simple_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                multiple                  50240
_________________________________________________________________
dense_1 (Dense)              multiple                  650
=================================================================
Total params: 50,890
Trainable params: 50,890
Non-trainable params: 0
_________________________________________________________________
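
Since the model is now built, you can sanity-check it with a forward pass on a dummy batch before training; the random input below is purely illustrative:

python
# A dummy batch of 32 flattened 28x28 images (batch size chosen arbitrarily)
dummy_batch = tf.random.normal((32, 784))
predictions = model(dummy_batch)
print(predictions.shape)  # (32, 10): one softmax distribution per sample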

Advanced Model Subclassing Techniques

Custom Layers and Advanced Architecture

Let's build a more complex model with custom layers and branching:

python
class ComplexModel(tf.keras.Model):
    def __init__(self):
        super(ComplexModel, self).__init__()
        # Feature extraction layers
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.conv2 = tf.keras.layers.Conv2D(64, 3, activation='relu')
        self.maxpool = tf.keras.layers.MaxPooling2D(2)
        self.flatten = tf.keras.layers.Flatten()

        # Classification branch
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.classifier = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=False):
        # Feature extraction
        x = self.conv1(inputs)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        x = self.flatten(x)

        # Classification branch with training-specific behavior
        features = self.dense1(x)
        if training:
            features = self.dropout(features)

        return self.classifier(features)

Notice how we can incorporate training-specific behavior like dropout using the training argument in the call method.
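
To see the effect of the training flag for yourself, you can run the model on a dummy batch; the input shape below is an assumption chosen only for illustration:

python
model = ComplexModel()

# A dummy batch of four 28x28 grayscale images (shape picked for illustration)
dummy = tf.random.normal((4, 28, 28, 1))

# Dropout is active only when training=True
train_out = model(dummy, training=True)
infer_out = model(dummy, training=False)
print(train_out.shape, infer_out.shape)  # (4, 10) (4, 10)

An equivalent and slightly more idiomatic variant is to forward the flag to the layer itself, i.e. features = self.dropout(features, training=training), since Dropout already disables itself when training is False.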

Adding Custom Training Logic

We can also customize the training step:

python
class CustomTrainingModel(tf.keras.Model):
    def __init__(self):
        super(CustomTrainingModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        self.accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    @tf.function
    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)

        # Calculate gradients and update weights
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        self.loss_tracker.update_state(loss)
        self.accuracy.update_state(y, y_pred)

        # Return metrics
        return {"loss": self.loss_tracker.result(), "accuracy": self.accuracy.result()}

Practical Example: Custom CNN for MNIST

Let's implement a complete example using the MNIST dataset:

python
import tensorflow as tf
from tensorflow.keras.datasets import mnist

# Load and preprocess the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0

# Define our custom CNN model
class MNISTModel(tf.keras.Model):
    def __init__(self):
        super(MNISTModel, self).__init__()
        # Convolutional layers
        self.conv1 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu')
        self.conv2 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')
        self.maxpool = tf.keras.layers.MaxPooling2D(2)
        self.dropout1 = tf.keras.layers.Dropout(0.25)

        # Fully connected layers
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dropout2 = tf.keras.layers.Dropout(0.5)
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=False):
        # Forward pass
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.maxpool(x)
        x = self.dropout1(x) if training else x

        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout2(x) if training else x

        return self.dense2(x)

# Create and compile the model
model = MNISTModel()
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model
model.fit(
    x_train, y_train,
    batch_size=128,
    epochs=5,
    validation_split=0.1
)

# Evaluate the model
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_accuracy:.4f}")

Output:

Epoch 1/5
422/422 [==============================] - 12s 27ms/step - loss: 0.3828 - accuracy: 0.8861 - val_loss: 0.0901 - val_accuracy: 0.9727
Epoch 2/5
422/422 [==============================] - 11s 26ms/step - loss: 0.1242 - accuracy: 0.9635 - val_loss: 0.0583 - val_accuracy: 0.9817
Epoch 3/5
422/422 [==============================] - 11s 26ms/step - loss: 0.0874 - accuracy: 0.9741 - val_loss: 0.0485 - val_accuracy: 0.9842
Epoch 4/5
422/422 [==============================] - 11s 26ms/step - loss: 0.0708 - accuracy: 0.9790 - val_loss: 0.0388 - val_accuracy: 0.9872
Epoch 5/5
422/422 [==============================] - 11s 26ms/step - loss: 0.0579 - accuracy: 0.9826 - val_loss: 0.0356 - val_accuracy: 0.9883
313/313 [==============================] - 2s 6ms/step - loss: 0.0324 - accuracy: 0.9896
Test accuracy: 0.9896

Real-world Application: Transfer Learning with Custom Model

Transfer learning is a powerful technique in deep learning. Let's create a custom model that incorporates a pre-trained base:

python
class TransferLearningModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(TransferLearningModel, self).__init__()

        # Use MobileNetV2 as the base model
        self.base_model = tf.keras.applications.MobileNetV2(
            include_top=False,
            weights='imagenet',
            input_shape=(224, 224, 3)
        )

        # Freeze the base model
        self.base_model.trainable = False

        # Add custom classification head
        self.global_pool = tf.keras.layers.GlobalAveragePooling2D()
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.dense1 = tf.keras.layers.Dense(256, activation='relu')
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.output_layer = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=False):
        # Pass inputs through the base model
        x = self.base_model(inputs, training=False)  # Base model always in inference mode

        # Pass through the classification head
        x = self.global_pool(x)
        x = self.batch_norm(x, training=training)
        x = self.dense1(x)
        if training:
            x = self.dropout(x)

        return self.output_layer(x)

    def get_config(self):
        return {"num_classes": self.output_layer.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

To use this model for a flower classification task:

python
# Example usage for flower classification
model = TransferLearningModel(num_classes=5) # 5 flower classes

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Now we can train the model with a flower dataset
# model.fit(flower_dataset, epochs=10, ...)
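
The commented-out flower_dataset is left as a placeholder. One common way to build it, assuming the images live in one sub-directory per class, is tf.keras.utils.image_dataset_from_directory; the directory name and batch size below are assumptions for illustration:

python
# Hypothetical layout: flower_photos/<class_name>/<image>.jpg
flower_dataset = tf.keras.utils.image_dataset_from_directory(
    'flower_photos',
    image_size=(224, 224),    # match the base model's input shape
    batch_size=32,
    label_mode='categorical'  # one-hot labels to match 'categorical_crossentropy'
)

# Scale pixels to the [-1, 1] range MobileNetV2 expects
preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
flower_dataset = flower_dataset.map(lambda images, labels: (preprocess(images), labels))

# model.fit(flower_dataset, epochs=10)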

Adding Custom Loss and Regularization

We can also incorporate custom losses and regularization in our models:

python
class RegularizedModel(tf.keras.Model):
    def __init__(self):
        super(RegularizedModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        return self.dense2(x)

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)

            # Main loss
            loss = self.compiled_loss(y, y_pred)

            # Add L2 regularization penalty
            l2_loss = 0
            for var in self.trainable_variables:
                l2_loss += tf.nn.l2_loss(var)

            # Add regularization to the loss
            loss += 1e-5 * l2_loss

        # Calculate gradients and update weights
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        self.compiled_metrics.update_state(y, y_pred)

        # Return metrics
        return {m.name: m.result() for m in self.metrics}
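
For plain L2 weight decay like this, a custom train_step isn't strictly necessary: the same idea can be expressed declaratively with kernel_regularizer, and Keras then adds the penalty to the training loss for you. A minimal sketch follows (note that regularizers.l2 penalizes only the kernel weights, not the biases, and omits the 1/2 factor of tf.nn.l2_loss, so the two versions are close but not numerically identical):

python
class RegularizedModelV2(tf.keras.Model):
    def __init__(self):
        super(RegularizedModelV2, self).__init__()
        # Keras collects these penalties in self.losses and adds them to the loss automatically
        self.dense1 = tf.keras.layers.Dense(
            64, activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(1e-5)
        )
        self.dense2 = tf.keras.layers.Dense(
            10, activation='softmax',
            kernel_regularizer=tf.keras.regularizers.l2(1e-5)
        )

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        return self.dense2(x)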

Summary

In this guide, we've explored TensorFlow's Model Subclassing API, which provides the ultimate flexibility for creating custom neural networks:

  • We learned how to create custom models by subclassing tf.keras.Model
  • We explored how to define custom layers and architectures with complex logic
  • We implemented custom training loops and integration with pre-trained models
  • We saw how to add custom losses and regularization techniques

Model Subclassing is ideal when you need:

  • Non-standard architectures
  • Custom layer interactions
  • Conditional execution paths
  • Fine-grained control over the training process

While the Functional and Sequential APIs are simpler, Model Subclassing gives you the power to implement cutting-edge architectures and research ideas using the full expressivity of Python.

Additional Resources and Exercises

Exercises

  1. Basic Exercise: Create a custom model with two dense layers and train it on the MNIST dataset.

  2. Intermediate Exercise: Implement a Residual Network (ResNet) block using Model Subclassing.

  3. Advanced Exercise: Create a multi-input model that processes both image and tabular data, then combines them before producing a final prediction.

  4. Research Exercise: Implement a recent research paper's architecture using Model Subclassing, and compare its performance with standard models.

  5. Practical Exercise: Build a custom model that incorporates attention mechanisms for a text classification task.

By mastering Model Subclassing, you'll be able to implement virtually any neural network architecture and have full control over your model's behavior.


