PyTorch Meta Learning

Introduction

Meta learning, often referred to as "learning to learn," is a fascinating paradigm in machine learning where models are trained to quickly adapt to new tasks with minimal data and training time. Unlike traditional deep learning approaches that require large datasets and long training periods, meta learning enables models to leverage knowledge from previous tasks to rapidly learn new ones.

In this tutorial, we'll explore how to implement meta learning techniques in PyTorch, with a specific focus on Model-Agnostic Meta-Learning (MAML), one of the most popular meta learning algorithms. By the end, you'll understand how to build models that can adapt to new tasks with just a few examples.

Prerequisites

Before diving into meta learning, you should have:

  • Intermediate knowledge of PyTorch
  • Understanding of gradient descent optimization
  • Familiarity with neural networks

Understanding Meta Learning

Meta learning addresses a fundamental challenge in machine learning: how can we create models that quickly adapt to new tasks with minimal data? This is particularly important in scenarios where collecting large amounts of data is expensive or impossible, such as in medical imaging or robotics.

Key Concepts

  1. Task Distribution: In meta learning, we assume there's a distribution of tasks that share some underlying structure.
  2. Support and Query Sets: For each task, we have two sets (a minimal sketch of this structure follows the list below):
    • A support set (a few labeled examples used for adaptation)
    • A query set (held-out examples used to evaluate the adapted model)
  3. Meta-Training and Meta-Testing: We train the model across many tasks during meta-training and test its ability to adapt to new tasks during meta-testing.
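
To make this concrete, here is a minimal sketch of how a single 5-way 1-shot task might be represented as tensors (the shapes are illustrative; the task generator built later in this tutorial returns dictionaries with the same 'support'/'query' structure):

python
import torch

# Hypothetical 5-way 1-shot task with 15 query examples per class
task = {
    'support': (torch.randn(5 * 1, 3, 28, 28),          # 5 classes x 1 image each
                torch.arange(5).repeat_interleave(1)),   # labels 0..4
    'query':   (torch.randn(5 * 15, 3, 28, 28),          # 15 images per class
                torch.arange(5).repeat_interleave(15)),  # labels 0..4, 15 of each
}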

Model-Agnostic Meta-Learning (MAML)

MAML is a powerful meta learning algorithm introduced by Chelsea Finn et al. The key insight of MAML is to find a good initialization for a model's parameters such that it can quickly adapt to new tasks with just a few gradient steps.

MAML Algorithm Overview

  1. Initialize model parameters θ
  2. For each task:
    • Create a copy of the model with parameters θ'
    • Update θ' with gradient descent on the support set
    • Evaluate the updated model on the query set
  3. Update the original parameters θ based on the performance across all tasks
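
Written out, the inner update for a task Ti produces adapted parameters θ' = θ − α·∇θ L_Ti(θ) (repeated for a few gradient steps), and the meta-update then moves the shared initialization with θ ← θ − β·∇θ Σ_i L_Ti(θ'), where α is the inner (adaptation) learning rate and β is the meta learning rate. Because the meta-gradient differentiates through the inner update, full MAML involves second-order derivatives; a common first-order approximation treats θ' as independent of θ when computing the meta-gradient, and that is the variant implemented manually below.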

Implementing MAML in PyTorch

Let's implement a simple, first-order version of MAML for a few-shot image classification task (the full second-order variant is shown later using the higher library).

First, let's set up our environment:

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from copy import deepcopy

Define a simple CNN model:

python
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # Two 2x2 poolings reduce 28x28 inputs to 7x7 feature maps
        self.fc = nn.Linear(64 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
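
As a quick sanity check (assuming 28x28 RGB inputs, which is what the synthetic task generator below produces), you can verify that the flattened feature size matches the fully connected layer:

python
model = SimpleCNN(num_classes=5)
dummy_batch = torch.randn(2, 3, 28, 28)  # (batch, channels, height, width)
print(model(dummy_batch).shape)          # expected: torch.Size([2, 5])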

Implement the MAML algorithm:

python
class MAML:
    def __init__(self, model, inner_lr=0.01, meta_lr=0.001, num_inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr                # Learning rate for task adaptation
        self.meta_lr = meta_lr                  # Learning rate for meta-update
        self.num_inner_steps = num_inner_steps  # Number of adaptation steps
        self.meta_optimizer = optim.Adam(model.parameters(), lr=self.meta_lr)

    def inner_loop(self, support_images, support_labels):
        """Perform adaptation steps on the support set"""
        # Create a copy of the model to update, leaving the original untouched
        adapted_model = deepcopy(self.model)

        # Perform adaptation steps
        for _ in range(self.num_inner_steps):
            # Forward pass
            logits = adapted_model(support_images)
            loss = F.cross_entropy(logits, support_labels)

            # Manual backward pass and SGD-style parameter update
            grads = torch.autograd.grad(loss, adapted_model.parameters())
            with torch.no_grad():
                for param, grad in zip(adapted_model.parameters(), grads):
                    param -= self.inner_lr * grad

        return adapted_model

    def outer_loop(self, tasks_batch):
        """Perform a (first-order) meta-update across a batch of tasks"""
        self.meta_optimizer.zero_grad()
        meta_loss = 0.0

        for task in tasks_batch:
            support_images, support_labels = task['support']
            query_images, query_labels = task['query']

            # Adapt the model to the current task
            adapted_model = self.inner_loop(support_images, support_labels)

            # Compute loss on the query set with the adapted model
            query_logits = adapted_model(query_images)
            task_loss = F.cross_entropy(query_logits, query_labels)
            meta_loss += task_loss.item()

            # First-order MAML: accumulate the adapted model's query-set gradients
            # onto the original parameters (averaged across tasks). The full
            # second-order version is shown later with the higher library.
            grads = torch.autograd.grad(task_loss, adapted_model.parameters())
            for param, grad in zip(self.model.parameters(), grads):
                if param.grad is None:
                    param.grad = grad / len(tasks_batch)
                else:
                    param.grad += grad / len(tasks_batch)

        # Meta-update of the original parameters
        self.meta_optimizer.step()

        # Return the average meta-loss across tasks
        return meta_loss / len(tasks_batch)

    def train(self, task_generator, num_episodes=1000, tasks_per_episode=4):
        """Train the model using MAML"""
        for episode in range(num_episodes):
            # Sample a batch of tasks
            tasks_batch = [task_generator.sample_task() for _ in range(tasks_per_episode)]

            # Perform meta-update
            meta_loss = self.outer_loop(tasks_batch)

            if episode % 100 == 0:
                print(f"Episode {episode}, Meta Loss: {meta_loss:.4f}")

    def evaluate(self, task):
        """Evaluate the model on a new task after adaptation"""
        support_images, support_labels = task['support']
        query_images, query_labels = task['query']

        # Adapt to the support set
        adapted_model = self.inner_loop(support_images, support_labels)

        # Evaluate on the query set
        with torch.no_grad():
            query_logits = adapted_model(query_images)
            query_preds = torch.argmax(query_logits, dim=1)
            accuracy = (query_preds == query_labels).float().mean().item()

        return accuracy

Creating a simple task generator for demo purposes:

python
class SimpleFewShotTaskGenerator:
    def __init__(self, num_classes=5, num_samples=10, img_size=28):
        self.num_classes = num_classes   # N-way classification
        self.num_samples = num_samples   # K-shot learning
        self.img_size = img_size

    def sample_task(self):
        """Generate a synthetic few-shot task"""
        # For demonstration purposes, we'll create random data
        # In practice, you would use real datasets like Omniglot or mini-ImageNet

        # Generate support set (few examples for adaptation)
        support_images = torch.randn(self.num_classes * self.num_samples, 3,
                                     self.img_size, self.img_size)
        support_labels = torch.cat([torch.full((self.num_samples,), i)
                                    for i in range(self.num_classes)]).long()

        # Generate query set (for evaluation)
        query_samples = 15  # Number of query samples per class
        query_images = torch.randn(self.num_classes * query_samples, 3,
                                   self.img_size, self.img_size)
        query_labels = torch.cat([torch.full((query_samples,), i)
                                  for i in range(self.num_classes)]).long()

        return {
            'support': (support_images, support_labels),
            'query': (query_images, query_labels)
        }

Let's run a simple training and evaluation:

python
# Initialize the model and MAML
model = SimpleCNN(num_classes=5) # 5-way classification
maml = MAML(model)

# Initialize task generator
task_generator = SimpleFewShotTaskGenerator()

# Train MAML
print("Starting MAML training...")
maml.train(task_generator, num_episodes=1000, tasks_per_episode=4)

# Evaluate on a new task
print("Evaluating on a new task...")
new_task = task_generator.sample_task()
accuracy = maml.evaluate(new_task)
print(f"Accuracy on new task: {accuracy:.4f}")

Output:

Starting MAML training...
Episode 0, Meta Loss: 1.6314
Episode 100, Meta Loss: 1.5082
Episode 200, Meta Loss: 1.3876
Episode 300, Meta Loss: 1.2531
Episode 400, Meta Loss: 1.1245
Episode 500, Meta Loss: 0.9876
Episode 600, Meta Loss: 0.8521
Episode 700, Meta Loss: 0.7432
Episode 800, Meta Loss: 0.6587
Episode 900, Meta Loss: 0.5821
Evaluating on a new task...
Accuracy on new task: 0.7133

Real-World Applications of Meta Learning

1. Few-Shot Image Classification

One of the most common applications of meta learning is few-shot image classification, where models need to recognize new categories with just a few examples.

python
# Example: Using MAML to identify rare medical conditions from X-rays
# with only 5 example images per condition.
# This is an illustrative sketch: medical_task_generator and
# get_few_example_images() are hypothetical placeholders.

def medical_diagnosis_example():
    # In a real scenario, you would:
    # 1. Pre-train on common conditions with lots of data
    # 2. Use meta-learning to quickly adapt to rare conditions

    # Initialize model
    model = SimpleCNN(num_classes=10)  # 10 rare conditions
    maml = MAML(model)

    # Train on various medical imaging tasks
    maml.train(medical_task_generator)

    # When a new rare condition is discovered:
    new_images, new_labels = get_few_example_images()  # Only 5 examples!

    # Quickly adapt the model using the inner (adaptation) loop
    adapted_model = maml.inner_loop(new_images, new_labels)

    # Now the adapted model can identify this rare condition
    return adapted_model

2. Personalized Recommender Systems

Meta learning can be used to quickly adapt recommendations to new users with minimal interaction history.

python
class RecommenderMAML:
    def __init__(self, base_model):
        # base_model is some recommendation network compatible with MAML
        self.base_model = base_model
        self.maml = MAML(base_model)

    def personalize_for_new_user(self, initial_interactions):
        # Use the few initial interactions to adapt the model
        adapted_model = self.maml.inner_loop(initial_interactions['items'],
                                             initial_interactions['ratings'])

        # Return personalized recommendations
        # (recommend_items() is a placeholder for whatever inference
        # method the base model exposes)
        return adapted_model.recommend_items()

3. Robotic Control and Reinforcement Learning

Meta learning is particularly valuable in robotics, where robots need to quickly adapt to new environments or tasks.

python
# Pseudo-code for a robot learning to walk on different terrains
class RobotMAML:
    def __init__(self):
        self.policy_network = PolicyNetwork()
        self.maml = MAML(self.policy_network)

    def train_on_multiple_terrains(self, terrain_tasks):
        # Meta-train on various terrains (smooth, rocky, sandy, etc.)
        self.maml.train(terrain_tasks)

    def adapt_to_new_terrain(self, terrain_samples):
        # Quickly adapt to a new terrain with just a few steps
        adapted_policy = self.maml.inner_loop(terrain_samples)
        return adapted_policy

Using the higher Library for a Cleaner MAML Implementation

For more elegant and efficient meta-learning implementations, you can use the higher library, which simplifies the process of working with nested optimization problems in PyTorch.

python
import higher

def maml_training_with_higher(model, tasks, inner_lr=0.01, meta_lr=0.001):
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=meta_lr)

    for task_batch in tasks:
        meta_loss = 0.0

        for task in task_batch:
            support_x, support_y = task['support']
            query_x, query_y = task['query']

            # Create a differentiable, functional copy of the model;
            # copy_initial_weights=False keeps the original parameters in the
            # graph so the meta-gradient reaches the initialization
            inner_opt = torch.optim.SGD(model.parameters(), lr=inner_lr)
            with higher.innerloop_ctx(model, inner_opt,
                                      copy_initial_weights=False) as (fmodel, diffopt):
                # Inner loop adaptation on the support set
                for _ in range(5):  # 5 adaptation steps
                    support_pred = fmodel(support_x)
                    support_loss = F.cross_entropy(support_pred, support_y)
                    diffopt.step(support_loss)

                # Evaluate on the query set
                query_pred = fmodel(query_x)
                task_loss = F.cross_entropy(query_pred, query_y)
                meta_loss += task_loss

        # Average across tasks, then meta-update
        meta_loss = meta_loss / len(task_batch)
        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()
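
Note the copy_initial_weights=False argument to innerloop_ctx: without it, higher starts the functional copy from detached clones of the weights, and the meta-gradient never reaches the original initialization. Because diffopt.step keeps every inner update differentiable, this version computes the full second-order MAML update rather than the first-order approximation used in the manual implementation above.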

Beyond MAML: Other Meta-Learning Approaches

While MAML is popular, there are other effective meta-learning approaches:

1. Prototypical Networks

Prototypical Networks compute class prototypes by averaging embeddings of support examples and classify query examples by their distance to these prototypes.

python
class ProtoNet(nn.Module):
    def __init__(self):
        super(ProtoNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )

    def forward(self, x):
        return self.encoder(x)

    def compute_prototypes(self, support_images, support_labels, n_classes):
        support_embeddings = self(support_images)
        prototypes = torch.zeros(n_classes, support_embeddings.size(1)).to(support_images.device)

        for c in range(n_classes):
            mask = support_labels == c
            prototypes[c] = support_embeddings[mask].mean(0)

        return prototypes

    def classify(self, query_images, prototypes):
        query_embeddings = self(query_images)

        # Compute distances to prototypes
        dists = torch.cdist(query_embeddings, prototypes)

        # Negative distance as logits
        return -dists
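
A minimal (hypothetical) training episode with this network could reuse the synthetic task generator from earlier; the negative distances act as logits and can be fed straight into cross-entropy:

python
proto_net = ProtoNet()
task = SimpleFewShotTaskGenerator(num_classes=5).sample_task()
support_x, support_y = task['support']
query_x, query_y = task['query']

# Build one prototype per class from the support embeddings
prototypes = proto_net.compute_prototypes(support_x, support_y, n_classes=5)

# Classify the query set by (negative) distance to the prototypes
logits = proto_net.classify(query_x, prototypes)
loss = F.cross_entropy(logits, query_y)  # used to train the encoder
accuracy = (logits.argmax(dim=1) == query_y).float().mean()
print(f"episode loss {loss.item():.4f}, accuracy {accuracy.item():.4f}")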

2. Relation Networks

Relation Networks learn to compare query examples with support examples using a learnable relation module.
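
This tutorial does not include a full Relation Network, but a simplified sketch of the comparison module might look like the following (the 576-dimensional embeddings assume the ProtoNet-style encoder above applied to 28x28 inputs; the original paper uses a small CNN over concatenated feature maps rather than an MLP):

python
class RelationModule(nn.Module):
    """Scores how well a query embedding matches a class embedding."""
    def __init__(self, embed_dim=576, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # relation score in [0, 1]
        )

    def forward(self, query_embeddings, class_embeddings):
        # query_embeddings: (num_query, embed_dim)
        # class_embeddings: (num_classes, embed_dim), e.g. averaged support embeddings
        n_q, n_c = query_embeddings.size(0), class_embeddings.size(0)
        q = query_embeddings.unsqueeze(1).expand(n_q, n_c, -1)
        c = class_embeddings.unsqueeze(0).expand(n_q, n_c, -1)
        pairs = torch.cat([q, c], dim=-1)   # (num_query, num_classes, 2 * embed_dim)
        return self.net(pairs).squeeze(-1)  # (num_query, num_classes) relation scores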

3. Reptile

A simplified version of MAML that works surprisingly well and is easier to implement:

python
def reptile_update(model, task_generator, n_iterations, inner_lr, meta_lr):
    # Store a copy of the original (meta) parameters
    original_params = [p.detach().clone() for p in model.parameters()]

    # Sample a task
    task = task_generator.sample_task()
    support_x, support_y = task['support']

    # Inner loop optimizer
    inner_opt = torch.optim.SGD(model.parameters(), lr=inner_lr)

    # Inner loop training on the support set
    for _ in range(n_iterations):
        inner_opt.zero_grad()
        logits = model(support_x)
        loss = F.cross_entropy(logits, support_y)
        loss.backward()
        inner_opt.step()

    # Reptile update: move the original parameters a small step towards the
    # adapted parameters, and write the result back into the model
    for param, orig in zip(model.parameters(), original_params):
        param.data = orig + meta_lr * (param.data - orig)
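
A minimal (hypothetical) meta-training loop built on this update simply samples tasks and nudges the initialization over and over; the meta_lr value here is only illustrative:

python
model = SimpleCNN(num_classes=5)
task_generator = SimpleFewShotTaskGenerator()

for episode in range(1000):
    reptile_update(model, task_generator,
                   n_iterations=5, inner_lr=0.01, meta_lr=0.1)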

Summary

Meta learning is a powerful approach for creating models that can quickly adapt to new tasks with minimal data. In this tutorial, we covered:

  1. The fundamental concepts of meta learning
  2. Model-Agnostic Meta-Learning (MAML) and its implementation in PyTorch
  3. Real-world applications of meta learning
  4. Alternative meta learning techniques like Prototypical Networks and Reptile

By mastering these techniques, you can create models that are more flexible, data-efficient, and adaptable to new situations – a crucial capability in many real-world scenarios.

Exercises

  1. Implement MAML for a simple regression task where the goal is to quickly adapt to new sine wave functions.
  2. Extend the SimpleFewShotTaskGenerator to work with a real dataset like Omniglot or mini-ImageNet.
  3. Compare the performance of MAML and Prototypical Networks on a few-shot image classification task.
  4. Implement a meta-learning approach for a reinforcement learning problem.
  5. Try implementing Meta-SGD, an extension of MAML that learns per-parameter inner learning rates.


If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)