PyTorch Transfer Learning
Introduction
Transfer learning is a powerful machine learning technique where a model developed for one task is reused as the starting point for a model on a second task. In the context of deep learning and computer vision, this usually means taking a pre-trained neural network (often trained on a large dataset like ImageNet) and fine-tuning it for your specific use case.
This approach offers several significant advantages:
- Reduced training time: Pre-trained models have already learned to detect various features, so you don't need to train from scratch.
- Less data required: You can achieve good results with much smaller datasets than would be needed for training from scratch.
- Better performance: In many cases, transfer learning leads to better accuracy than training models from scratch, especially with limited data.
In this tutorial, we'll explore how to implement transfer learning with PyTorch for computer vision tasks.
Prerequisites
Before diving in, you should have:
- Basic understanding of Python programming
- Familiarity with PyTorch fundamentals
- Understanding of neural networks and convolutional neural networks (CNNs)
- PyTorch and torchvision installed
```bash
pip install torch torchvision
```
Understanding Transfer Learning Approaches
There are two main approaches to transfer learning:
- Feature Extraction: Keep the pre-trained weights fixed and train only a newly added final layer (or layers), so the network acts as a fixed feature extractor.
- Fine-tuning: Adapt the pre-trained weights by continuing the backpropagation through some or all layers of the network.
Let's explore both techniques using PyTorch.
1. Feature Extraction with Pre-trained Models
In this approach, we'll use a pre-trained model as a feature extractor by freezing its weights and only training a new classifier on top.
Step 1: Import necessary libraries
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
import time
import os
import copy
```
Step 2: Prepare your data
```python
# Data augmentation and normalization for training
# Just normalization for validation
# The mean/std values are the per-channel ImageNet statistics that the
# pre-trained torchvision models expect
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Example data directory structure:
# data/
# ├── train/
# │   ├── class1/
# │   └── class2/
# └── val/
#     ├── class1/
#     └── class2/
data_dir = 'data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```
Step 3: Load a pre-trained model and modify the classifier
```python
# Load a ResNet-18 with ImageNet weights (torchvision >= 0.13 uses the
# `weights` argument; older releases used `pretrained=True`)
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze all pre-trained parameters
for param in model_ft.parameters():
    param.requires_grad = False

# Replace the final fully connected layer (named `fc` in ResNet-18).
# A newly created layer has requires_grad=True, so only this layer is trained.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))  # One output per class

# Move the model to the appropriate device
model_ft = model_ft.to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Observe that only parameters of the final layer are being optimized
optimizer_ft = optim.SGD(model_ft.fc.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by a factor of 0.1 every 7 epochs
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
```
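`StepLR(step_size=7, gamma=0.1)` multiplies the learning rate by 0.1 once every 7 calls to `scheduler.step()`. `train_model` calls it once per training epoch, so 25 epochs pass through four learning-rate levels. The minimal sketch below records that schedule using a single dummy parameter instead of a real model.

```python
import torch
import torch.optim as optim

# One dummy parameter is enough to drive the optimizer and scheduler
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.001, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(25):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used during this epoch
    optimizer.step()   # a real loop would run forward/backward first
    scheduler.step()   # once per epoch, as in train_model

for start in (0, 7, 14, 21):
    print(f"epochs {start}+: lr = {lrs[start]:.0e}")
```

Epochs 0 to 6 run at 1e-03, epochs 7 to 13 at 1e-04, epochs 14 to 20 at 1e-05, and epochs 21 to 24 at 1e-06.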
Step 4: Train the model
```python
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward
                # Track history only in train phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward + optimize only in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model


# Train the model
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
```
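`train_model` stores `copy.deepcopy(model.state_dict())` rather than `model.state_dict()` itself. The state dict holds tensors that share storage with the live parameters, and the optimizer updates those parameters in place. A snapshot that is not deep-copied would therefore silently track the latest weights instead of the best ones. The small demonstration below uses a toy `nn.Linear` layer.

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

shallow = model.state_dict()                  # shares storage with the parameters
snapshot = copy.deepcopy(model.state_dict())  # independent copy

# One optimization step changes the weights in place
loss = model(torch.ones(1, 2)).sum()
loss.backward()
optimizer.step()

print(torch.equal(shallow["weight"], model.weight.data))   # True: followed the update
print(torch.equal(snapshot["weight"], model.weight.data))  # False: kept the old values
```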
2. Fine-tuning the Pre-trained Model
In fine-tuning, we allow some or all of the pre-trained model's weights to be updated during training.
```python
# Load a pre-trained model again
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Fine-tune all layers (requires_grad=True is already the default for a
# freshly loaded model; alternatively, you could freeze some earlier layers)
for param in model_ft.parameters():
    param.requires_grad = True

# Replace final fully connected layer
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))  # One output per class

# Move to device
model_ft = model_ft.to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

# Optimize all parameters
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Learning rate scheduler
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# Train and evaluate
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
```
Choosing the Right Pre-trained Model
PyTorch offers numerous pre-trained models through the torchvision.models module. Here are some popular choices:
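```python
# A few of the available architectures; each call downloads the
# ImageNet weights the first time it is made
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
alexnet = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
densenet = models.densenet161(weights=models.DenseNet161_Weights.DEFAULT)
# Inception v3 expects 299x299 inputs and, in train mode,
# returns (logits, aux_logits) instead of a single tensor
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
efficientnet = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
```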
When choosing a model, consider:
- Model size: Larger models might be more accurate but require more memory and computation
- Speed: Some architectures, such as EfficientNet, are designed to reach good accuracy with less computation
- Task similarity: Choose models pre-trained on tasks similar to yours
Real-world Application: Cat vs. Dog Classifier
Let's implement a complete example of a cat vs. dog classifier using transfer learning.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

# Data preparation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Assumes train/ and val/ folders, each containing cat/ and dog/ image folders
data_dir = 'cat_dog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Visualize some training images
def imshow(inp, title=None):
    """Undo the normalization of a (C, H, W) tensor and display it."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # Pause briefly so the plot is updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))
# Make a grid from batch
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])

# Load a pretrained model and modify it
model_ft = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)  # 2 classes - cat and dog
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# Use the train_model function defined earlier
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)

# Save the trained model
torch.save(model_ft.state_dict(), 'cat_dog_classifier.pth')


# Visualize model predictions
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)


# Visualize some sample predictions
visualize_model(model_ft)
```
Using Your Model for Inference
After training, you'll want to use your model to make predictions on new images:
```python
from PIL import Image


def predict_image(image_path, model):
    # Load and preprocess the image (same transforms as the validation set)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # convert('RGB') handles grayscale and RGBA files, which would otherwise
    # not match the 3-channel input the network expects
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # Add batch dimension

    # Move to device and evaluate
    model.eval()
    with torch.no_grad():
        image = image.to(device)
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)

    return class_names[predicted.item()]


# Example usage
image_path = 'path/to/new/image.jpg'
prediction = predict_image(image_path, model_ft)
print(f"Predicted class: {prediction}")
```
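`predict_image` returns only the arg-max class. When you also want a confidence score, apply softmax to the logits. The sketch below works on the raw `outputs` tensor of a single-image batch, so it could replace the `torch.max` line; the helper name `top_k_predictions` is hypothetical. Softmax probabilities from a fine-tuned network are not necessarily well calibrated, so treat them as relative scores.

```python
import torch


def top_k_predictions(outputs, class_names, k=2):
    """Hypothetical helper: [(class_name, probability), ...] for a batch of one."""
    probs = torch.softmax(outputs, dim=1)      # logits -> probabilities per row
    top_probs, top_idx = probs.topk(k, dim=1)  # k most likely classes
    return [(class_names[i], p)
            for i, p in zip(top_idx[0].tolist(), top_probs[0].tolist())]


# Hand-written logits for the ['cat', 'dog'] classes
logits = torch.tensor([[0.5, 2.5]])
preds = top_k_predictions(logits, ["cat", "dog"])
for name, prob in preds:
    print(f"{name}: {prob:.3f}")  # dog: 0.881, then cat: 0.119
```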
Tips for Effective Transfer Learning
- Choose the right pre-trained model: Consider both performance and computational requirements.
- Data augmentation: Often crucial when fine-tuning with small datasets.
- Gradual unfreezing: Start by training only the last layer, then gradually unfreeze and train deeper layers.
- Lower learning rates: Use smaller learning rates when fine-tuning to avoid drastically changing the pre-trained weights.
- Monitor validation performance: Watch for overfitting, especially with small datasets.
When Not to Use Transfer Learning
Transfer learning isn't always the best choice:
- When your data is very different from the pre-training dataset
- When you have a very large dataset specific to your problem
- When your task is fundamentally different from the pre-trained model's task
Summary
Transfer learning is a powerful technique that lets you build accurate computer vision models with limited data and computational resources. In this tutorial, you've learned how to:
- Choose appropriate pre-trained models from torchvision
- Implement feature extraction by freezing pre-trained layers
- Fine-tune a pre-trained model on your own dataset
- Apply data augmentation to improve performance
- Use your trained model for making predictions on new images
By reusing the features learned by pre-trained networks, you can often reach strong results even with modest datasets. This approach has made deep learning practical for a much wider range of computer vision applications and developers.
Additional Resources and Exercises
Exercises
- Different Pre-trained Models: Modify the examples to use different pre-trained models (VGG, DenseNet, EfficientNet) and compare their performance.
- Different Freezing Strategies: Experiment with freezing different portions of the network. Try freezing only the first few layers, or implementing progressive unfreezing.
- Custom Dataset: Apply transfer learning to a domain-specific dataset of your choice. For example, train a model to recognize different types of plants or food dishes.
- Learning Rate Exploration: Experiment with different learning rate schedules for fine-tuning and observe how they affect the model's performance.
- Model Distillation: Research and implement a simple version of knowledge distillation, where a larger pre-trained model is used to train a smaller, more efficient model.