TensorFlow Transfer Learning

Introduction

Transfer learning is a powerful technique in deep learning that allows you to take a model trained on one task and repurpose it for a related task. Instead of building and training a neural network from scratch, transfer learning enables you to leverage pre-trained models that have already learned useful features from massive datasets like ImageNet.

In this tutorial, we'll explore how to use transfer learning with TensorFlow to solve image classification problems more efficiently. Transfer learning offers several advantages:

  • Reduced training time: Pre-trained models have already learned useful features
  • Less data required: You can achieve good results with smaller datasets
  • Better performance: Often leads to higher accuracy than training from scratch

Understanding Transfer Learning

What is Transfer Learning?

Transfer learning works on the principle that the features learned by a neural network in one domain can be useful in another related domain. For example, a model trained to recognize cats and dogs has already learned to detect edges, shapes, textures, and more complex patterns; these same features are useful for many other image classification tasks.

There are two main approaches to transfer learning:

  1. Feature Extraction: Using a pre-trained model as a fixed feature extractor
  2. Fine-Tuning: Further training (adapting) a pre-trained model on your new dataset

Let's visualize this concept:

Pre-trained Model (e.g., MobileNetV2)
├── Base Layers (feature extraction)
│   ├── Conv Layer 1
│   ├── Conv Layer 2
│   ├── ...
│   └── Conv Layer N
└── Classification Head
    └── Dense Layer (1000 classes for ImageNet)

↓ Transfer Learning ↓

Your Custom Model
├── Base Layers (frozen or fine-tuned)
│   ├── Conv Layer 1 (frozen)
│   ├── Conv Layer 2 (frozen)
│   ├── ...
│   └── Conv Layer N (maybe fine-tuned)
└── New Classification Head
    └── Dense Layer (your specific classes)

Getting Started with Transfer Learning in TensorFlow

Let's implement transfer learning using TensorFlow and the Keras API. We'll use a pre-trained MobileNetV2 model and adapt it for a flower classification task.

1. Import the Required Libraries

python
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import numpy as np

2. Load and Prepare the Dataset

For this example, we'll use the TensorFlow Flowers dataset:

python
# Download the flowers dataset
import tensorflow_datasets as tfds

# Load the flowers dataset with an 80/20 train/validation split
(train_dataset, validation_dataset), info = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
    with_info=True
)

num_classes = info.features['label'].num_classes
class_names = info.features['label'].names
print(f"Number of classes: {num_classes}")
print(f"Class names: {class_names}")

# Define image size based on the pre-trained model's requirements
IMG_SIZE = 224

# Function to preprocess images
def preprocess_image(image, label):
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    # Scale pixel values to [-1, 1], the range MobileNetV2 was trained with
    image = tf.keras.applications.mobilenet_v2.preprocess_input(tf.cast(image, tf.float32))
    return image, label

# Apply preprocessing
batch_size = 32
train_dataset = train_dataset.map(preprocess_image).batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
validation_dataset = validation_dataset.map(preprocess_image).batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

Output:

Number of classes: 5
Class names: ['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses']

3. Load a Pre-trained Model

Now, let's load the MobileNetV2 model pre-trained on ImageNet:

python
# Load the pre-trained model, excluding the classification head
base_model = MobileNetV2(weights='imagenet',
                         include_top=False,
                         input_shape=(IMG_SIZE, IMG_SIZE, 3))

# Freeze the base model so we don't change its weights during training
base_model.trainable = False

# View the model structure
print(f"Number of layers in the base model: {len(base_model.layers)}")

Output:

Number of layers in the base model: 155

4. Add a Custom Classification Head

We need to add new layers on top of the pre-trained model for our specific classification task:

python
# Create the model architecture
model = models.Sequential([
    # Pre-trained base model
    base_model,

    # Add a global average pooling layer
    layers.GlobalAveragePooling2D(),

    # Add a dense hidden layer
    layers.Dense(256, activation='relu'),

    # Add dropout to prevent overfitting
    layers.Dropout(0.5),

    # Output layer with 5 units (one for each flower class)
    layers.Dense(num_classes, activation='softmax')
])

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Model summary
model.summary()

Output:

Model: "sequential"
_________________________________________________________________
Layer (type)                  Output Shape              Param #
=================================================================
mobilenetv2_1.00_224 (Model)  (None, 7, 7, 1280)        2257984
_________________________________________________________________
global_average_pooling2d (Gl  (None, 1280)              0
_________________________________________________________________
dense (Dense)                 (None, 256)               327936
_________________________________________________________________
dropout (Dropout)             (None, 256)               0
_________________________________________________________________
dense_1 (Dense)               (None, 5)                 1285
=================================================================
Total params: 2,587,205
Trainable params: 329,221
Non-trainable params: 2,257,984
_________________________________________________________________

5. Train the Model

Now, let's train the model with our feature extraction approach:

python
# Train the model
history = model.fit(
    train_dataset,
    epochs=10,
    validation_data=validation_dataset
)

# Plot training history
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()

Output (approximate results):

Epoch 1/10
115/115 [==============================] - 24s 208ms/step - loss: 0.8615 - accuracy: 0.6973 - val_loss: 0.3980 - val_accuracy: 0.8646
Epoch 2/10
115/115 [==============================] - 23s 200ms/step - loss: 0.4363 - accuracy: 0.8566 - val_loss: 0.3156 - val_accuracy: 0.8827
...
Epoch 10/10
115/115 [==============================] - 23s 201ms/step - loss: 0.1733 - accuracy: 0.9434 - val_loss: 0.1989 - val_accuracy: 0.9362

6. Fine-tuning the Model (Advanced Approach)

For even better performance, we can unfreeze some of the top layers of the base model and continue training at a very low learning rate:

python
# Unfreeze the top layers of the base model
base_model.trainable = True
for layer in base_model.layers[:100]:
    layer.trainable = False

# Compile the model with a lower learning rate
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),  # Much lower learning rate
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Continue training with fine-tuning
fine_tune_history = model.fit(
    train_dataset,
    epochs=5,
    validation_data=validation_dataset
)

Output (approximate results):

Epoch 1/5
115/115 [==============================] - 42s 364ms/step - loss: 0.1611 - accuracy: 0.9475 - val_loss: 0.1425 - val_accuracy: 0.9482
...
Epoch 5/5
115/115 [==============================] - 42s 363ms/step - loss: 0.0739 - accuracy: 0.9756 - val_loss: 0.1284 - val_accuracy: 0.9633
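
To see the effect of fine-tuning at a glance, the accuracy curves from both training phases can be joined into one plot. Here is a small sketch that simply concatenates the two history objects returned by the fit calls above:

python
# Combine the accuracy curves from the feature-extraction and fine-tuning phases
acc = history.history['accuracy'] + fine_tune_history.history['accuracy']
val_acc = history.history['val_accuracy'] + fine_tune_history.history['val_accuracy']

plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
# Mark the epoch at which fine-tuning began
plt.axvline(x=len(history.history['accuracy']) - 1, color='gray', linestyle='--',
            label='Start of fine-tuning')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()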

7. Evaluate the Model and Make Predictions

Let's see our model in action by making predictions on some sample images:

python
# Function to make predictions on a batch of images
def make_predictions(dataset):
    images, labels = next(iter(dataset.take(1)))
    predictions = model.predict(images)
    predicted_classes = np.argmax(predictions, axis=1)

    # Display some sample images with predictions
    plt.figure(figsize=(14, 8))
    for i in range(min(9, len(images))):
        plt.subplot(3, 3, i+1)
        # Rescale images from [-1, 1] back to [0, 1] for display
        plt.imshow((images[i] + 1) / 2)
        plt.title(f"True: {class_names[int(labels[i])]}\nPred: {class_names[predicted_classes[i]]}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Make predictions
make_predictions(validation_dataset)
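
To complement the visual check, we can also compute overall metrics on the validation set; evaluate returns the loss and the accuracy metric the model was compiled with:

python
# Overall loss and accuracy on the validation set
val_loss, val_accuracy = model.evaluate(validation_dataset)
print(f"Validation loss: {val_loss:.4f}")
print(f"Validation accuracy: {val_accuracy:.4f}")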

Real-world Applications of Transfer Learning

Transfer learning has revolutionized how we approach deep learning problems, especially in scenarios with limited data. Here are some practical applications:

1. Medical Image Analysis

Medical datasets are often small due to privacy concerns and the high cost of labeling. Transfer learning allows medical AI systems to benefit from models pre-trained on large natural image datasets, then fine-tuned for specific tasks like:

  • X-ray classification
  • Tumor detection in MRI scans
  • Skin lesion classification

python
# Example: Adapting a pre-trained model for X-ray classification
def create_chest_xray_model():
    base_model = tf.keras.applications.DenseNet121(
        weights='imagenet',
        include_top=False,
        input_shape=(224, 224, 3)
    )

    # Freeze the base
    base_model.trainable = False

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(14, activation='sigmoid')  # 14 conditions in the ChestX-ray14 dataset
    ])

    return model
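
Because a single chest X-ray can show several conditions at once, the head above uses independent sigmoid outputs rather than a softmax. A minimal sketch of how such a multi-label model could be compiled (the binary cross-entropy loss and AUC metric here are illustrative choices, not requirements of any particular dataset):

python
# Compile the multi-label X-ray model: binary cross-entropy treats each of the
# 14 outputs as an independent yes/no prediction
xray_model = create_chest_xray_model()
xray_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='binary_crossentropy',
    metrics=[tf.keras.metrics.AUC(name='auc')]
)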

2. Custom Object Detection

Businesses often need to detect specific objects relevant to their domain. Transfer learning makes this much more practical:

python
# Pseudocode for adapting a pre-trained backbone for object detection
def create_custom_detector(num_classes):
    # Load a MobileNetV2 backbone with ImageNet weights
    base_model = tf.keras.applications.MobileNetV2(
        weights='imagenet',
        include_top=False,
        input_shape=(300, 300, 3)
    )

    # Add custom detection heads (e.g., SSD-style box and class predictors)
    # ... (detection heads would be added here)
    detection_model = ...  # placeholder: assemble the backbone plus detection heads

    # Return the model with new detection heads for your classes
    return detection_model

3. Mobile Applications

Transfer learning is particularly valuable for deploying AI on mobile devices:

  • Smaller models can be derived from larger pre-trained networks
  • On-device image recognition becomes feasible
  • Customized for specific use cases while maintaining efficiency
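
As an example of the last point, a transfer-learned Keras model such as the flower classifier built earlier can be converted to TensorFlow Lite for on-device inference. A minimal sketch (the output filename is just an example):

python
# Convert the trained Keras model to TensorFlow Lite for mobile deployment
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # enables default size/latency optimizations
tflite_model = converter.convert()

# Save the converted model to disk
with open('flower_classifier.tflite', 'wb') as f:
    f.write(tflite_model)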

Best Practices for Transfer Learning

  1. Choose the right pre-trained model:

    • For limited computational resources: MobileNet, EfficientNet
    • For high accuracy: ResNet, VGG, Inception
    • For object detection: SSD, Faster R-CNN
  2. Data preprocessing:

    • Match the preprocessing of your data to that used by the pre-trained model
    • Use the same image size and normalization method
  3. Freezing vs. fine-tuning:

    • Small dataset → Freeze more layers
    • Large dataset → Fine-tune more layers
  4. Learning rate:

    • Use a small learning rate when fine-tuning (typically 10x smaller than when training from scratch)
  5. Data augmentation:

    • Always use data augmentation when you have limited training data
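
As an illustration of the last point, augmentation can be added as preprocessing layers in front of the frozen base. A minimal sketch reusing the base_model and num_classes from earlier (the specific layers and parameters are just examples):

python
# Simple augmentation block; these layers are only active during training
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])

augmented_model = tf.keras.Sequential([
    data_augmentation,
    base_model,  # frozen pre-trained backbone
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])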

Summary

In this tutorial, you've learned how to implement transfer learning using TensorFlow and pre-trained models. We covered:

  • The concept and benefits of transfer learning
  • How to use a pre-trained model as a feature extractor
  • Adding custom classification layers for your specific task
  • Fine-tuning the pre-trained model for improved performance
  • Real-world applications and best practices

Transfer learning is an essential technique that allows you to build powerful deep learning models even with limited data and computational resources. By leveraging pre-trained models, you can achieve impressive results on a wide range of computer vision tasks with much less training time and data.

Exercises

  1. Try implementing transfer learning with a different pre-trained model like ResNet50 or EfficientNetB0.
  2. Apply transfer learning to a different dataset, such as CIFAR-10 or a custom dataset of your choice.
  3. Experiment with different fine-tuning strategies: try unfreezing different numbers of layers and observe the effects.
  4. Implement a visualization to see what features the pre-trained model activates on for your dataset.
  5. Create a model that can distinguish between types of vehicles (cars, motorcycles, bicycles, trucks, buses) using transfer learning.

