TensorFlow Fine-tuning
Introduction
Fine-tuning is a powerful technique in deep learning that allows you to take advantage of pre-trained models. Instead of training a complex neural network from scratch—which requires massive amounts of data and computational resources—fine-tuning lets you "borrow" knowledge from existing models and adapt them to your specific task.
In this tutorial, we'll explore how to fine-tune pre-trained Convolutional Neural Networks (CNNs) using TensorFlow. This approach is especially valuable when you have limited training data or computational resources, as it significantly reduces training time while often improving performance.
What is Fine-tuning?
Fine-tuning builds on a concept called transfer learning, where knowledge gained while solving one problem is applied to a different but related problem. In the context of CNNs:
- You start with a pre-trained model (like VGG16, ResNet, or MobileNet) that has been trained on a large dataset (typically ImageNet with over 1 million images)
- You remove the final classification layer of the pre-trained model
- You add your own classification layer(s) specific to your task
- You "fine-tune" some or all of the model's weights on your specific dataset
Why Fine-tune Instead of Training from Scratch?
- Less data needed: Pre-trained models already know how to extract useful features from images
- Faster convergence: Training starts from a good initialization point rather than random weights
- Better performance: Especially when your dataset is small (hundreds or thousands vs. millions of images)
- Less computational resources: You can achieve state-of-the-art results with less training time
Prerequisites
Before we start, make sure you have the following installed:
pip install tensorflow tensorflow-hub matplotlib numpy pillow
Basic Fine-tuning Workflow
Let's walk through a complete example of fine-tuning a MobileNetV2 model for a custom classification task.
1. Import Required Libraries
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
2. Prepare Your Dataset
For this example, let's use a simplified flower classification dataset:
# Download a sample dataset (5 flower types)
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
# Count the images
image_count = len(list(data_dir.glob('*/*.jpg')))
print(f"Total images: {image_count}")
# Check class names
class_names = [item.name for item in data_dir.glob('*') if item.is_dir()]
print(f"Classes: {class_names}")
Output:
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/? [......] - ETA: 0s
Total images: 3670
Classes: ['sunflowers', 'roses', 'dandelion', 'daisy', 'tulips']
3. Set Up Data Generators
# Parameters
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
# Create training and validation splits
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=20,
zoom_range=0.2,
horizontal_flip=True,
validation_split=0.2 # 20% for validation
)
# Load and prepare the training dataset
train_generator = train_datagen.flow_from_directory(
data_dir,
target_size=(IMG_HEIGHT, IMG_WIDTH),
batch_size=BATCH_SIZE,
class_mode='categorical',
subset='training'
)
# Load and prepare the validation dataset
validation_generator = train_datagen.flow_from_directory(
data_dir,
target_size=(IMG_HEIGHT, IMG_WIDTH),
batch_size=BATCH_SIZE,
class_mode='categorical',
subset='validation'
)
Output:
Found 2939 images belonging to 5 classes.
Found 731 images belonging to 5 classes.
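A note on the input pipeline: ImageDataGenerator still works, but it is deprecated in recent TensorFlow releases in favor of the tf.data API. The rest of this tutorial keeps using the generators above; if you prefer tf.data, a roughly equivalent sketch (same data_dir, image size, and 80/20 split, with augmentation done by Keras preprocessing layers) could look like this:
# Alternative: a tf.data pipeline (sketch only, not used later in this tutorial)
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    label_mode="categorical"
)
val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    label_mode="categorical"
)
# Rescale both splits, augment only the training split
rescale = tf.keras.layers.Rescaling(1.0 / 255)
augment = tf.keras.Sequential([
    tf.keras.layers.RandomRotation(0.06),  # roughly the 20-degree rotation used above
    tf.keras.layers.RandomZoom(0.2),
    tf.keras.layers.RandomFlip("horizontal")
])
train_ds = train_ds.map(lambda x, y: (augment(rescale(x), training=True), y))
val_ds = val_ds.map(lambda x, y: (rescale(x), y))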
4. Create the Base Model from a Pre-trained ConvNet
# Load the MobileNetV2 model but exclude the classification layers
base_model = tf.keras.applications.MobileNetV2(
input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
include_top=False,
weights='imagenet'
)
# Initially, we freeze the base model to prevent its weights from being updated
base_model.trainable = False
# Let's look at the base model architecture
print(f"Number of layers in the base model: {len(base_model.layers)}")
Output:
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
9406464/9406464 [==============================] - 0s 0us/step
Number of layers in the base model: 154
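Since the base model is frozen, a quick optional sanity check (not part of the original output) is to confirm that it currently exposes no trainable variables:
print(len(base_model.trainable_variables))  # expected to print 0 while the base is frozen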
5. Build the Complete Fine-tuned Model
# Add our custom classification head
model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(5, activation='softmax') # 5 classes for our flower dataset
])
# Compile the model
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss='categorical_crossentropy',
metrics=['accuracy']
)
# Model summary
model.summary()
6. Train the Model (First Phase: Only Train the Head)
# Train only the top layers (randomly initialized) for a few epochs
history = model.fit(
train_generator,
steps_per_epoch=train_generator.samples // BATCH_SIZE,
epochs=10,
validation_data=validation_generator,
validation_steps=validation_generator.samples // BATCH_SIZE
)
Output:
Epoch 1/10
91/91 [==============================] - 36s 395ms/step - loss: 1.1153 - accuracy: 0.5493 - val_loss: 0.8277 - val_accuracy: 0.7123
Epoch 2/10
91/91 [==============================] - 35s 385ms/step - loss: 0.6975 - accuracy: 0.7599 - val_loss: 0.5854 - val_accuracy: 0.8060
...
Epoch 10/10
91/91 [==============================] - 35s 387ms/step - loss: 0.3367 - accuracy: 0.8801 - val_loss: 0.3940 - val_accuracy: 0.8626
7. Unfreeze and Fine-tune the Model (Second Phase)
Now that the head is trained, we can unfreeze some of the layers in our base model and fine-tune them:
# Unfreeze all or part of the base model
# Here we'll unfreeze the top 50 layers
base_model.trainable = True
# Freeze all the layers before the ones we want to fine-tune
for layer in base_model.layers[:-50]:
layer.trainable = False
# Recompile the model with a lower learning rate
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001), # Lower learning rate
loss='categorical_crossentropy',
metrics=['accuracy']
)
# Continue training with fine-tuning
fine_tune_history = model.fit(
train_generator,
steps_per_epoch=train_generator.samples // BATCH_SIZE,
epochs=20,  # total epochs: the first 10 plus 10 more of fine-tuning
validation_data=validation_generator,
validation_steps=validation_generator.samples // BATCH_SIZE,
initial_epoch=10 # Start from where we left off
)
Output:
Epoch 11/20
91/91 [==============================] - 67s 740ms/step - loss: 0.2608 - accuracy: 0.9092 - val_loss: 0.3064 - val_accuracy: 0.8982
Epoch 12/20
91/91 [==============================] - 67s 742ms/step - loss: 0.1980 - accuracy: 0.9320 - val_loss: 0.2738 - val_accuracy: 0.9053
...
Epoch 20/20
91/91 [==============================] - 66s 736ms/step - loss: 0.0813 - accuracy: 0.9730 - val_loss: 0.2506 - val_accuracy: 0.9208
8. Visualize Training Results
# Combine the histories from both training phases
acc = history.history['accuracy'] + fine_tune_history.history['accuracy']
val_acc = history.history['val_accuracy'] + fine_tune_history.history['val_accuracy']
loss = history.history['loss'] + fine_tune_history.history['loss']
val_loss = history.history['val_loss'] + fine_tune_history.history['val_loss']
plt.figure(figsize=(14, 4))
# Plot accuracy
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.plot([9, 9], [0, 1], 'r--', label='Start Fine Tuning')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
# Plot loss
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.plot([9, 9], [0, 1.5], 'r--', label='Start Fine Tuning')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.ylim([0, 1.5])
plt.legend(loc='upper right')
plt.tight_layout()
plt.show()
9. Evaluate and Save the Model
# Evaluate the model
loss, accuracy = model.evaluate(validation_generator)
print(f"Validation accuracy: {accuracy:.3f}")
# Save the model
model.save('fine_tuned_flower_model.h5')
Output:
23/23 [==============================] - 7s 303ms/step - loss: 0.2506 - accuracy: 0.9208
Validation accuracy: 0.921
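If you want to reuse the saved model later, it can be loaded back directly (newer TensorFlow releases also support the native .keras format); for example:
# Reload the fine-tuned model for inference (quick usage sketch)
reloaded = tf.keras.models.load_model('fine_tuned_flower_model.h5')
# predictions = reloaded.predict(image_batch)  # image_batch is a placeholder for your own data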
Advanced Fine-tuning Strategies
1. Feature Extraction vs. Fine-tuning
There are two main approaches to transfer learning:
Feature Extraction (what we did in the first phase)
- Keep the convolutional base frozen
- Only train the newly added classifier layers
- Faster and works well when your new dataset is small or similar to the original dataset
Fine-tuning (what we did in the second phase)
- Unfreeze some layers of the pre-trained network
- Train both the new layers and the unfrozen layers
- More powerful but risks overfitting on small datasets (both approaches are sketched side by side below)
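As a compact reference, the two approaches differ only in which weights are trainable and in the learning rate; here is a side-by-side sketch reusing base_model and model from the tutorial above:
# Feature extraction: base frozen, only the new head learns
base_model.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss='categorical_crossentropy', metrics=['accuracy'])

# Fine-tuning: top of the base unfrozen, much smaller learning rate
base_model.trainable = True
for layer in base_model.layers[:-50]:
    layer.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='categorical_crossentropy', metrics=['accuracy'])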
2. Progressive Fine-tuning
A common technique is to gradually unfreeze layers from top to bottom, since:
- Higher layers (closer to output) contain more specialized features
- Lower layers (closer to input) contain more generic features (edges, colors, textures)
# Example of progressive unfreezing
# First, train only the classifier layers (feature extraction)
# Then, unfreeze the last block of the base model
# Finally, unfreeze more blocks as needed
# Start with all layers frozen
base_model.trainable = False
# Train for a few epochs with just the head...
# Then unfreeze the last block and train with a lower learning rate
base_model.trainable = True
for layer in base_model.layers[:-30]: # Freeze all except the last 30 layers
layer.trainable = False
# Train for a few more epochs...
# Finally, unfreeze more layers if needed. Reset the whole base to trainable
# first, otherwise the layers frozen above simply stay frozen.
base_model.trainable = True
for layer in base_model.layers[:-50]:  # Freeze all except the last 50 layers
    layer.trainable = False
# Train with an even lower learning rate...
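To avoid repeating this unfreeze-and-recompile boilerplate, you can wrap it in a small helper. This is only a sketch; unfreeze_top is a hypothetical name, not a Keras API:
def unfreeze_top(model, base, num_layers, learning_rate):
    """Unfreeze the last `num_layers` layers of `base`, then recompile `model`."""
    base.trainable = True
    for layer in base.layers[:-num_layers]:
        layer.trainable = False
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

# Progressively widen the trainable region with shrinking learning rates
unfreeze_top(model, base_model, 30, 1e-5)
# model.fit(...)  # train for a few epochs
unfreeze_top(model, base_model, 50, 5e-6)
# model.fit(...)  # train for a few more epochs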
3. Choosing Which Pre-trained Model to Use
TensorFlow provides many pre-trained models to choose from:
# For large datasets or high accuracy needs:
base_model = tf.keras.applications.ResNet50V2(weights='imagenet', include_top=False)
# For mobile or edge devices:
base_model = tf.keras.applications.MobileNetV2(weights='imagenet', include_top=False)
# For a balance between size and performance:
base_model = tf.keras.applications.EfficientNetB0(weights='imagenet', include_top=False)
4. Using TensorFlow Hub
TF Hub provides an even easier way to use pre-trained models:
import tensorflow_hub as hub
# Load a feature vector model from TF Hub
feature_extractor_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
feature_extractor_layer = hub.KerasLayer(
feature_extractor_url,
input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
trainable=False
)
# Build your model using the feature extractor
model = tf.keras.Sequential([
feature_extractor_layer,
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(5, activation='softmax')
])
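The resulting model is compiled and trained exactly like the one built from tf.keras.applications; for example:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
# model.fit(train_generator, validation_data=validation_generator, epochs=10)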
Real-world Application: Skin Lesion Classification
Let's consider a practical medical application: fine-tuning a model to classify skin lesions as benign or malignant.
# This is just a code sketch - in reality you'd need the actual dataset
# 1. Start with a pre-trained model good at feature extraction
base_model = tf.keras.applications.DenseNet121(
weights='imagenet',
include_top=False,
input_shape=(224, 224, 3)
)
# 2. Freeze the base model
base_model.trainable = False
# 3. Create the complete model with a custom head
model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(256, activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(1, activation='sigmoid') # Binary classification
])
# 4. Compile the model (class weights for the imbalanced data are passed to fit(), not compile(); see below)
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
loss='binary_crossentropy',
metrics=['accuracy', tf.keras.metrics.AUC()]
)
# 5. Use callbacks for early stopping and model checkpointing
callbacks = [
tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
tf.keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)
]
# 6. Train with data augmentation specific to medical imaging
# (In a real scenario, you would configure your data generators here)
# 7. After initial training, fine-tune the last few blocks. Set the whole base
# back to trainable first (flipping sublayers alone has no effect while
# base_model.trainable is still False), then re-freeze all but the last 30 layers
base_model.trainable = True
for layer in base_model.layers[:-30]:
    layer.trainable = False
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), # Lower learning rate
loss='binary_crossentropy',
metrics=['accuracy', tf.keras.metrics.AUC()]
)
# 8. Continue training with fine-tuning
# model.fit(...)
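Class weights are applied at training time rather than at compile time. The sketch below (with made-up class counts and hypothetical train_ds/val_ds datasets) shows how the imbalance handling mentioned in step 4 would actually be wired into fit():
# Hypothetical class counts for an imbalanced dataset
n_benign, n_malignant = 8000, 2000
total = n_benign + n_malignant
class_weight = {
    0: total / (2.0 * n_benign),      # benign    -> weight 0.625
    1: total / (2.0 * n_malignant)    # malignant -> weight 2.5
}
# Pass the weights (and the callbacks defined above) to fit()
# model.fit(train_ds, validation_data=val_ds, epochs=20,
#           class_weight=class_weight, callbacks=callbacks)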
This approach is commonly used in medical imaging where datasets are limited and the cost of misclassification is high.
Common Pitfalls and Solutions
Overfitting
Problem: The fine-tuned model performs well on training data but poorly on new data.
Solutions:
- Use more aggressive data augmentation
- Apply regularization (Dropout, L2); see the sketch after this list
- Unfreeze fewer layers
- Use a smaller learning rate
- Implement early stopping
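As an example of the regularization point, here is a minimal sketch of a more heavily regularized head (L2 weight decay plus stronger dropout), assuming the same frozen base_model as in the tutorial:
regularized_model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(
        128,
        activation='relu',
        kernel_regularizer=tf.keras.regularizers.l2(1e-4)  # L2 penalty on the head weights
    ),
    tf.keras.layers.Dropout(0.5),  # stronger dropout than the 0.2 used earlier
    tf.keras.layers.Dense(5, activation='softmax')
])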
Catastrophic Forgetting
Problem: The model "forgets" what it learned from the pre-trained weights.
Solutions:
- Use a very small learning rate when fine-tuning
- Implement gradual unfreezing
- Consider using techniques like knowledge distillation
Preprocessing Mismatch
Problem: The preprocessing for your images doesn't match what the pre-trained model expects.
Solution: Always use the preprocessing function that comes with the pre-trained model. Note that the earlier example rescaled pixels to [0, 1] with rescale=1./255, whereas MobileNetV2's preprocess_input maps raw pixel values to [-1, 1], so the two should not be combined:
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
# Correct preprocessing for MobileNetV2
train_datagen = ImageDataGenerator(
preprocessing_function=preprocess_input,
rotation_range=20,
zoom_range=0.2,
horizontal_flip=True,
validation_split=0.2
)
Summary
In this tutorial, we've covered:
- The concept of fine-tuning - leveraging pre-trained models to solve new tasks
- Step-by-step fine-tuning process - from loading pre-trained models to training and evaluation
- Advanced strategies - progressive unfreezing, model selection, and TF Hub integration
- Real-world application - applying fine-tuning to medical image classification
- Common pitfalls and solutions - avoiding overfitting and other common issues
Fine-tuning pre-trained CNN models is one of the most powerful techniques in computer vision. It allows you to achieve excellent results with limited data and computational resources, making advanced deep learning accessible even to beginners.
Additional Resources
- TensorFlow Transfer Learning Guide
- TensorFlow Hub for pre-trained models
- Keras Applications documentation
- CS231n Stanford Course for deeper understanding of CNNs
Exercises
- Fine-tune a different pre-trained model (like ResNet50 or EfficientNet) on the same flower dataset and compare the results.
- Apply the fine-tuning approach to a different dataset of your choice (e.g., food images, animal species).
- Experiment with different layer freezing strategies to find the optimal approach for your dataset.
- Implement a learning rate scheduler that gradually decreases the learning rate during fine-tuning.
- Use TensorBoard to visualize the training process and feature maps of the fine-tuned model.
If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)