Skip to main content

PyTorch Multi-Task Learning

Introduction

Multi-Task Learning (MTL) is a powerful paradigm in deep learning where a single model is trained to perform multiple related tasks simultaneously. Instead of creating separate models for each task, MTL leverages the shared representations across tasks to improve overall performance, reduce overfitting, and save computational resources.

In this tutorial, we'll explore how to implement Multi-Task Learning using PyTorch. You'll learn:

  • The fundamental concepts behind Multi-Task Learning
  • How to structure a multi-task neural network
  • Techniques for balancing multiple loss functions
  • Practical implementation with real-world examples

What is Multi-Task Learning?

Multi-Task Learning is based on the principle that learning related tasks together can be beneficial as the model can leverage shared patterns across tasks. For instance, a model that simultaneously learns to detect objects, segment images, and estimate depth might perform better on all three tasks than three separate models specialized for each task.

Benefits of Multi-Task Learning:

  • Improved data efficiency: The model can learn from signals of related tasks
  • Reduced overfitting: Additional tasks act as regularization
  • Faster inference: One model instead of multiple models
  • Shared feature learning: Common representations beneficial for all tasks

Basic Multi-Task Learning Architecture in PyTorch

Let's start by implementing a simple multi-task neural network that can perform two related tasks: predicting both the category and the price of a product based on its description.

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskModel(nn.Module):
def __init__(self, input_size, hidden_size, num_categories, price_range):
super(MultiTaskModel, self).__init__()

# Shared layers
self.embedding = nn.Linear(input_size, hidden_size)
self.shared_layer = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU()
)

# Task-specific layers
self.category_classifier = nn.Linear(hidden_size, num_categories)
self.price_predictor = nn.Linear(hidden_size, 1)
self.price_range = price_range

def forward(self, x):
# Shared representation
embedded = self.embedding(x)
shared_features = self.shared_layer(embedded)

# Task-specific predictions
category_logits = self.category_classifier(shared_features)
price = self.price_predictor(shared_features) * self.price_range

return {
'category': category_logits,
'price': price
}

This architecture has:

  1. A shared representation learned by initial layers
  2. Task-specific branches that specialize in each task

Multi-Task Loss Functions

With multiple tasks, we need to combine multiple loss functions. A common approach is to use a weighted sum of individual task losses:

python
def multi_task_loss(predictions, targets, loss_weights):
category_loss = F.cross_entropy(predictions['category'], targets['category'])
price_loss = F.mse_loss(predictions['price'], targets['price'])

# Calculate weighted loss
total_loss = loss_weights['category'] * category_loss + loss_weights['price'] * price_loss

return total_loss, {'category': category_loss.item(), 'price': price_loss.item()}

Dynamic Loss Weighting

Determining the right balance between task weights can be challenging. Let's implement a popular technique called "uncertainty weighting" proposed by Kendall et al. (2018):

python
class UncertaintyWeightedLoss(nn.Module):
def __init__(self, num_tasks):
super(UncertaintyWeightedLoss, self).__init__()
# Log variance parameters for each task
self.log_vars = nn.Parameter(torch.zeros(num_tasks))

def forward(self, losses):
# Get weights from log variances
precision = torch.exp(-self.log_vars)

# Calculate weighted loss
weighted_losses = [precision[i] * losses[i] + self.log_vars[i]/2 for i in range(len(losses))]

return sum(weighted_losses)

Training a Multi-Task Model

Now let's implement the training loop for our multi-task model:

python
def train_multi_task_model(model, train_loader, optimizer, loss_weights, epochs=10):
model.train()

for epoch in range(epochs):
total_loss = 0
task_losses = {'category': 0, 'price': 0}

for batch_idx, (data, targets) in enumerate(train_loader):
optimizer.zero_grad()

# Forward pass
predictions = model(data)

# Compute loss
loss, batch_task_losses = multi_task_loss(predictions, targets, loss_weights)

# Backward pass and optimization
loss.backward()
optimizer.step()

# Track metrics
total_loss += loss.item()
for task, task_loss in batch_task_losses.items():
task_losses[task] += task_loss

# Print epoch statistics
print(f"Epoch {epoch+1}/{epochs}")
print(f"Total Loss: {total_loss/len(train_loader):.4f}")
for task, task_loss in task_losses.items():
print(f"{task.capitalize()} Loss: {task_loss/len(train_loader):.4f}")
print("-" * 30)

Complete Example: E-commerce Product Analysis

Let's create a complete example for an e-commerce use case where we want to classify products and predict their prices:

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader

# Custom dataset for e-commerce products
class ProductDataset(Dataset):
def __init__(self, num_samples=1000, feature_dim=50):
self.num_samples = num_samples
self.feature_dim = feature_dim

# Generate synthetic data
self.features = torch.randn(num_samples, feature_dim)

# Generate synthetic targets
self.categories = torch.randint(0, 5, (num_samples,)) # 5 product categories

# Price depends somewhat on category (related tasks)
base_prices = torch.tensor([10.0, 25.0, 50.0, 100.0, 200.0])
self.prices = base_prices[self.categories] + 10 * torch.randn(num_samples, 1)

def __len__(self):
return self.num_samples

def __getitem__(self, idx):
features = self.features[idx]
targets = {
'category': self.categories[idx],
'price': self.prices[idx]
}
return features, targets

# Create dataset and dataloaders
train_dataset = ProductDataset(num_samples=800)
test_dataset = ProductDataset(num_samples=200)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Initialize model
input_size = 50
hidden_size = 100
num_categories = 5
price_range = 300.0

model = MultiTaskModel(input_size, hidden_size, num_categories, price_range)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initial task weights
loss_weights = {'category': 1.0, 'price': 0.5}

# Train the model
train_multi_task_model(model, train_loader, optimizer, loss_weights, epochs=5)

# Evaluate model
def evaluate_model(model, test_loader):
model.eval()
total_samples = 0
correct_categories = 0
price_mse = 0

with torch.no_grad():
for data, targets in test_loader:
predictions = model(data)

# Category accuracy
predicted_categories = torch.argmax(predictions['category'], dim=1)
correct_categories += (predicted_categories == targets['category']).sum().item()

# Price MSE
price_mse += F.mse_loss(predictions['price'], targets['price'], reduction='sum').item()

total_samples += data.size(0)

category_accuracy = correct_categories / total_samples
avg_price_mse = price_mse / total_samples

print(f"Test Category Accuracy: {category_accuracy:.4f}")
print(f"Test Price MSE: {avg_price_mse:.4f}")

evaluate_model(model, test_loader)

Expected output:

Epoch 1/5
Total Loss: 2.3147
Category Loss: 1.5892
Price Loss: 1.4510
------------------------------
...
Epoch 5/5
Total Loss: 1.0732
Category Loss: 0.9887
Price Loss: 0.1691
------------------------------
Test Category Accuracy: 0.6150
Test Price MSE: 145.2834

Hard Parameter Sharing vs. Soft Parameter Sharing

There are two main approaches to multi-task learning architectures:

Hard Parameter Sharing

This is what we've implemented above - the model has shared layers followed by task-specific layers. This is the most common approach and helps prevent overfitting.

Soft Parameter Sharing

In soft parameter sharing, each task has its own model with its own parameters, but the parameters are regularized to be similar:

python
class SoftSharedModel(nn.Module):
def __init__(self, input_size, hidden_size, num_categories, price_range):
super(SoftSharedModel, self).__init__()

# Task-specific models
self.category_model = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, num_categories)
)

self.price_model = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 1)
)

self.price_range = price_range

def forward(self, x):
category_logits = self.category_model(x)
price = self.price_model(x) * self.price_range

return {
'category': category_logits,
'price': price
}

def get_l2_reg_loss(self, alpha=0.01):
"""Calculate L2 regularization loss between corresponding layers"""
reg_loss = 0

# Get corresponding layers from each model
for i in range(0, len(self.category_model), 2): # Skip activation layers
if isinstance(self.category_model[i], nn.Linear) and isinstance(self.price_model[i], nn.Linear):
# L2 regularization between corresponding weights
reg_loss += torch.sum((self.category_model[i].weight - self.price_model[i].weight) ** 2)

return alpha * reg_loss

Real-World Application: Multi-Task Computer Vision

Multi-task learning is widely used in computer vision. Let's build a simple model that performs image classification and segmentation:

python
class MultiTaskVisionModel(nn.Module):
def __init__(self):
super(MultiTaskVisionModel, self).__init__()

# Use a pre-trained model as the shared backbone
self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# Remove the classification head
self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])

# Task-specific heads
# Classification head
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(512, 10) # 10 classes
)

# Segmentation head
self.segmenter = nn.Sequential(
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 1, kernel_size=3, padding=1),
nn.Sigmoid()
)

def forward(self, x):
# Shared features
features = self.backbone(x)

# Task-specific outputs
classification = self.classifier(features)
segmentation = self.segmenter(features)

return {
'class': classification,
'segmentation': segmentation
}

Dealing with Task Imbalance

Sometimes one task can dominate the learning process. Here are techniques to address this:

Gradient Normalization

python
def train_with_gradient_norm(model, data, optimizer, losses_func):
optimizer.zero_grad()

# Forward pass
outputs = model(data)

# Calculate individual task losses
task_losses = losses_func(outputs, targets)

# Backward pass with normalized gradients
for i, loss in enumerate(task_losses):
if i == 0:
loss.backward(retain_graph=True)
else:
# Get gradients from previous tasks
grads = []
for param in model.parameters():
if param.grad is not None:
grads.append(param.grad.view(-1))

# Flatten gradients
grads = torch.cat(grads)

# Calculate current gradients
optimizer.zero_grad()
loss.backward(retain_graph=(i < len(task_losses)-1))

# Get current gradients
curr_grads = []
for param in model.parameters():
if param.grad is not None:
curr_grads.append(param.grad.view(-1))

curr_grads = torch.cat(curr_grads)

# Project for orthogonality
proj = torch.dot(curr_grads, grads) / torch.dot(grads, grads)
curr_grads = curr_grads - proj * grads

# Put projected gradients back
idx = 0
for param in model.parameters():
if param.grad is not None:
param_size = param.grad.numel()
param.grad = curr_grads[idx:idx+param_size].view_as(param.grad)
idx += param_size

# Update weights
optimizer.step()

Summary

In this tutorial, we've explored Multi-Task Learning with PyTorch, covering:

  • Basic principles of multi-task learning
  • How to build multi-task models with shared representations
  • Balancing multiple loss functions
  • Hard vs. soft parameter sharing
  • Advanced techniques for handling task imbalance
  • Real-world applications in e-commerce and computer vision

Multi-task learning is a powerful technique that can improve data efficiency, reduce overfitting, and potentially improve overall performance when tasks are related. As you build more complex deep learning systems, consider whether your tasks might benefit from a multi-task approach.

Additional Resources

Exercises

  1. Modify the e-commerce example to add a third task: predicting whether a product is in stock (binary classification).
  2. Implement the uncertainty weighting technique and compare it with fixed loss weights.
  3. Try adapting a pre-trained vision model (like ResNet) for multiple image tasks using transfer learning.
  4. Experiment with different shared layer architectures to see how they affect performance on individual tasks.
  5. Build a multi-task NLP model that can perform both sentiment analysis and named entity recognition on the same text input.


If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)