TensorFlow Meta-Learning


Meta-learning, often described as "learning to learn," represents an advanced paradigm in machine learning where systems improve their learning abilities over time and across different tasks. Unlike traditional machine learning approaches that start from scratch for each new task, meta-learning leverages experience from previous learning tasks to rapidly adapt to new, similar tasks with minimal data.

In this tutorial, we'll explore how to implement meta-learning techniques using TensorFlow. This approach is particularly valuable for scenarios where you have limited data for new tasks or need models that can quickly adapt to changing environments.

Key Concepts in Meta-Learning

Before diving into code, let's understand some fundamental concepts:

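  1. Few-shot Learning: Training models to recognize new classes from just a few labeled examples per class (an N-way K-shot task has N classes with K examples each)
  2. Model-Agnostic Meta-Learning (MAML): A popular meta-learning algorithm that learns a parameter initialization from which a few gradient steps on a new task give good performance
  3. Learning to Learn: The process where a model improves its learning algorithm through experience
  4. Support and Query Sets: Within each task, the support set is the small labeled set the model adapts on, and the query set is held-out data from the same task used to evaluate the adapted model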
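
In equation form, one common way to write MAML uses meta-parameters θ, an inner learning rate α and a meta learning rate β (these play the roles of inner_lr and meta_lr in the code below, although meta_lr there is passed to Adam rather than to plain gradient descent). Each task 𝒯_i has a support set S_i and a query set Q_i:

```latex
\theta_i' = \theta - \alpha \, \nabla_\theta \mathcal{L}_{S_i}(\theta)
\quad \text{(inner loop: adapt on the support set of task } \mathcal{T}_i\text{)}

\theta \leftarrow \theta - \beta \, \nabla_\theta \sum_i \mathcal{L}_{Q_i}(\theta_i')
\quad \text{(outer loop: evaluate the adapted parameters on the query sets)}
```

Because each θ_i' is itself a function of θ, the outer gradient differentiates through the inner update; this is what distinguishes MAML from simply training one model on all tasks.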

Implementing a Simple Meta-Learning Framework

Let's start by setting up a basic meta-learning framework using TensorFlow:

import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, models

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

Creating a Meta-Model Architecture

We'll first define a simple neural network that will serve as our base model:

def create_model():
    model = models.Sequential([
        layers.Input(shape=(28*28,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(64, activation='relu'),
        layers.Dense(10, activation='softmax')
    ])
    return model
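
Every inner-loop step in MAML updates all trainable parameters of this network, so it is worth knowing how many there are. A Dense layer with n inputs and m units has n·m weights plus m biases. The hypothetical helper below applies that rule to the 784 → 64 → 64 → 10 stack in plain Python; calling create_model().count_params() in Keras should report the same total.

```python
def dense_param_count(layer_sizes):
    """Count weights + biases for a stack of fully connected layers.

    layer_sizes lists the input size followed by each layer's unit count.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + bias vector
    return total


# 784 inputs -> Dense(64) -> Dense(64) -> Dense(10)
print(dense_param_count([28 * 28, 64, 64, 10]))  # 55050
```

Almost all of these parameters (50,240) sit in the first layer, which is typical for MLPs on flattened images.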

Implementing MAML (Model-Agnostic Meta-Learning)

Now, let's implement a simplified version of MAML, one of the most popular meta-learning algorithms:

class MAML:
    """Simplified MAML using the first-order approximation (FOMAML).

    The gradient of the query loss at the adapted weights is used as the
    meta-gradient; the second-order terms of full MAML are dropped. This
    keeps the code compatible with any Keras model built by model_fn.
    """

    def __init__(self, model_fn, inner_lr=0.01, meta_lr=0.001):
        self.model_fn = model_fn
        self.meta_model = model_fn()
        self.inner_lr = inner_lr
        self.meta_optimizer = tf.keras.optimizers.Adam(meta_lr)
        # Scratch model whose weights are reset to the meta-weights for each task
        self.adapted_model = model_fn()
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

    def adapt(self, support_x, support_y, num_inner_steps=1):
        """Adapt model parameters to a new task based on support examples"""
        # Start from the current meta-parameters, not a fresh random initialization
        self.adapted_model.set_weights(self.meta_model.get_weights())

        # Inner loop: plain gradient descent on the support set
        for _ in range(num_inner_steps):
            with tf.GradientTape() as tape:
                predictions = self.adapted_model(support_x, training=True)
                loss = self.loss_fn(support_y, predictions)

            gradients = tape.gradient(loss, self.adapted_model.trainable_variables)
            for var, grad in zip(self.adapted_model.trainable_variables, gradients):
                var.assign_sub(self.inner_lr * grad)

        return self.adapted_model

    def meta_train_step(self, tasks):
        """Perform a meta-training step on a batch of tasks"""
        meta_loss = 0.0
        meta_gradients = [tf.zeros_like(var) for var in self.meta_model.trainable_variables]

        for support_x, support_y, query_x, query_y in tasks:
            # Adapt a copy of the meta-parameters to the support set
            adapted_model = self.adapt(support_x, support_y)

            # Evaluate the adapted parameters on the query set
            with tf.GradientTape() as tape:
                query_predictions = adapted_model(query_x, training=True)
                query_loss = self.loss_fn(query_y, query_predictions)

            # First-order approximation: the query gradient w.r.t. the adapted
            # weights stands in for the gradient w.r.t. the meta-weights
            query_gradients = tape.gradient(query_loss, adapted_model.trainable_variables)
            meta_gradients = [mg + g / len(tasks) for mg, g in zip(meta_gradients, query_gradients)]
            meta_loss += float(query_loss) / len(tasks)

        # Apply the task-averaged gradients to the meta-parameters
        self.meta_optimizer.apply_gradients(zip(meta_gradients, self.meta_model.trainable_variables))

        return meta_loss
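
The practical difference between full MAML and a first-order approximation lies in the meta-gradient. With one inner step θ' = θ − α∇L_S(θ), the chain rule gives ∇_θ L_Q(θ') = (I − α∇²L_S(θ)) ∇L_Q(θ'), while first-order MAML keeps only ∇L_Q(θ'). For linear least squares the Hessian is the constant matrix (2/n)XᵀX, so both quantities have closed forms. The NumPy sketch below relies on that simplifying assumption (a linear model, one inner step, and hypothetical helper names such as maml_meta_gradients); it illustrates the math separately from the TensorFlow class and checks the exact meta-gradient against central finite differences.

```python
import numpy as np


def mse_loss(theta, X, y):
    """Mean squared error of the linear model X @ theta."""
    residual = X @ theta - y
    return np.mean(residual ** 2)


def mse_grad(theta, X, y):
    """Gradient of mse_loss with respect to theta."""
    return 2.0 / len(y) * X.T @ (X @ theta - y)


def maml_meta_gradients(theta, support, query, alpha):
    """Return (full MAML, first-order MAML) meta-gradients for one inner step."""
    Xs, ys = support
    Xq, yq = query
    # Inner loop: one gradient step on the support set
    theta_adapted = theta - alpha * mse_grad(theta, Xs, ys)
    # Query gradient evaluated at the adapted parameters
    g_query = mse_grad(theta_adapted, Xq, yq)
    # Hessian of the support loss (constant for least squares)
    hessian = 2.0 / len(ys) * Xs.T @ Xs
    full = (np.eye(len(theta)) - alpha * hessian) @ g_query
    first_order = g_query  # drops the second-order term
    return full, first_order


def meta_objective(theta, support, query, alpha):
    """Query loss after one inner step, as a function of the meta-parameters."""
    theta_adapted = theta - alpha * mse_grad(theta, *support)
    return mse_loss(theta_adapted, *query)


rng = np.random.default_rng(0)
support = (rng.normal(size=(5, 3)), rng.normal(size=5))
query = (rng.normal(size=(10, 3)), rng.normal(size=10))
theta = rng.normal(size=3)
alpha = 0.1

full, first_order = maml_meta_gradients(theta, support, query, alpha)

# Central finite differences of the meta-objective, one coordinate at a time
eps = 1e-6
numeric = np.array([
    (meta_objective(theta + eps * e, support, query, alpha)
     - meta_objective(theta - eps * e, support, query, alpha)) / (2 * eps)
    for e in np.eye(3)
])
print(np.allclose(full, numeric, atol=1e-6))  # True
print(np.linalg.norm(full - first_order))
```

In a deep network the Hessian term cannot be written down directly; full MAML obtains it by differentiating through the inner update (for example with nested tf.GradientTape contexts), which is why it costs more memory and compute than the first-order version.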

Preparing Data for Meta-Learning

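For meta-learning, we need to structure our data differently: instead of one large training set, we generate many small tasks (episodes), each with its own support and query set. Let's prepare the MNIST dataset as a collection of such tasks: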

def prepare_mnist_tasks(n_way=5, k_shot=1, n_query=5, n_tasks=100):
    """Prepare meta-learning tasks from the MNIST dataset.

    n_way: Number of classes per task
    k_shot: Number of support examples per class
    n_query: Number of query examples per class
    n_tasks: Total number of tasks to generate
    """
    # Load and preprocess MNIST
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape(-1, 28*28).astype('float32') / 255.0

    # Group examples by class
    class_examples = [x_train[y_train == i] for i in range(10)]

    tasks = []
    for _ in range(n_tasks):
        # Randomly select n_way classes
        selected_classes = np.random.choice(10, n_way, replace=False)

        support_x = []
        support_y = []
        query_x = []
        query_y = []

        for i, cls in enumerate(selected_classes):
            # Select k_shot + n_query distinct examples from this class
            examples = class_examples[cls]
            indices = np.random.choice(len(examples), k_shot + n_query, replace=False)

            # Split into support and query sets
            support_indices = indices[:k_shot]
            query_indices = indices[k_shot:k_shot + n_query]

            support_x.append(examples[support_indices])
            support_y.append(np.full(k_shot, i, dtype=np.int64))  # Relabel as 0, 1, 2, ...
            query_x.append(examples[query_indices])
            query_y.append(np.full(n_query, i, dtype=np.int64))

        # Combine and shuffle support set
        support_x = np.vstack(support_x)
        support_y = np.concatenate(support_y)
        perm = np.random.permutation(len(support_y))
        support_x = support_x[perm]
        support_y = support_y[perm]

        # Combine and shuffle query set
        query_x = np.vstack(query_x)
        query_y = np.concatenate(query_y)
        perm = np.random.permutation(len(query_y))
        query_x = query_x[perm]
        query_y = query_y[perm]

        tasks.append((support_x, support_y, query_x, query_y))

    return tasks
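
The sampling logic inside prepare_mnist_tasks does not depend on MNIST itself. The sketch below isolates it as a hypothetical helper, sample_episode, that works on any integer label array and returns indices instead of images, so the N-way K-shot invariants can be checked on synthetic labels without downloading data.

```python
import numpy as np


def sample_episode(labels, n_way, k_shot, n_query, rng):
    """Sample one N-way K-shot episode from an array of integer class labels.

    Returns (support_idx, support_y, query_idx, query_y), where *_idx index
    into the original dataset and *_y are episode labels 0..n_way-1.
    """
    classes = rng.choice(np.unique(labels), size=n_way, replace=False)
    support_idx, support_y, query_idx, query_y = [], [], [], []
    for episode_label, cls in enumerate(classes):
        # Draw k_shot + n_query distinct examples of this class
        members = np.flatnonzero(labels == cls)
        picked = rng.choice(members, size=k_shot + n_query, replace=False)
        support_idx.append(picked[:k_shot])
        query_idx.append(picked[k_shot:])
        support_y.append(np.full(k_shot, episode_label))
        query_y.append(np.full(n_query, episode_label))
    return (np.concatenate(support_idx), np.concatenate(support_y),
            np.concatenate(query_idx), np.concatenate(query_y))


# Synthetic dataset: 10 classes with 20 examples each
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 20)
s_idx, s_y, q_idx, q_y = sample_episode(labels, n_way=5, k_shot=1, n_query=5, rng=rng)
print(len(s_idx), len(q_idx))  # 5 25
```

Relabeling each episode's classes as 0..n_way-1 is what lets a single 5-way output layer serve every task, regardless of which original classes were drawn.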

Training and Evaluating a Meta-Learning Model

Now let's train our meta-learning framework:

# Prepare meta-training and meta-testing tasks
# (both draw from all ten digit classes; a stricter setup would hold out
# some classes for meta-testing so that test tasks contain unseen classes)
meta_train_tasks = prepare_mnist_tasks(n_way=5, k_shot=5, n_query=10, n_tasks=1000)
meta_test_tasks = prepare_mnist_tasks(n_way=5, k_shot=5, n_query=10, n_tasks=100)

# Initialize MAML
maml = MAML(create_model, inner_lr=0.05, meta_lr=0.001)

# Meta-training loop
batch_size = 4
n_epochs = 5

for epoch in range(n_epochs):
    # Shuffle tasks
    perm = np.random.permutation(len(meta_train_tasks))
    meta_train_tasks = [meta_train_tasks[j] for j in perm]

    total_loss = 0.0
    n_batches = len(meta_train_tasks) // batch_size

    for i in range(n_batches):
        task_batch = meta_train_tasks[i * batch_size:(i + 1) * batch_size]
        loss = maml.meta_train_step(task_batch)
        total_loss += float(loss)

        if i % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {i}/{n_batches}, Loss: {float(loss):.4f}")

    avg_loss = total_loss / n_batches
    print(f"Epoch {epoch+1} completed, Average Loss: {avg_loss:.4f}")

# Meta-testing
test_accuracies = []

for task in meta_test_tasks:
    support_x, support_y, query_x, query_y = task

    # Adapt model to support set (5 inner steps here, versus 1 during meta-training)
    adapted_model = maml.adapt(support_x, support_y, num_inner_steps=5)

    # Evaluate on query set
    query_predictions = adapted_model(query_x)
    query_predicted_classes = tf.argmax(query_predictions, axis=1)

    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(query_predicted_classes, tf.cast(query_y, tf.int64)), tf.float32)
    )
    test_accuracies.append(float(accuracy))

mean_accuracy = np.mean(test_accuracies)
print(f"Meta-test accuracy: {mean_accuracy:.4f}")
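
Few-shot results are usually reported as the mean accuracy over test episodes together with a 95% confidence interval, since accuracy varies considerably from one episode to the next. A minimal sketch, assuming a normal approximation (1.96 standard errors, computed from the sample standard deviation), that could be applied to the per-task accuracies collected in test_accuracies; mean_with_ci is a hypothetical helper name:

```python
import numpy as np


def mean_with_ci(accuracies, z=1.96):
    """Return (mean, half-width) of an approximate 95% confidence interval."""
    accs = np.asarray(accuracies, dtype=np.float64)
    mean = accs.mean()
    # Standard error of the mean from the sample standard deviation
    std_err = accs.std(ddof=1) / np.sqrt(len(accs))
    return mean, z * std_err


mean, half_width = mean_with_ci([0.5, 0.7])
print(f"{mean:.3f} +/- {half_width:.3f}")  # 0.600 +/- 0.196
```

With 100 test tasks, as above, the interval narrows by a factor of ten relative to the per-episode spread.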

Real-World Application: Few-Shot Image Classification

Let's implement a more practical example using meta-learning for few-shot image classification with the Omniglot dataset, a common meta-learning benchmark of 1,623 handwritten characters from 50 alphabets, each drawn by 20 different people:

# First, let's download and prepare the Omniglot dataset
import tensorflow_datasets as tfds

# Load Omniglot dataset
omniglot_ds = tfds.load('omniglot', split='train+test', as_supervised=True)

def prepare_omniglot_tasks(dataset, n_way=5, k_shot=1, n_query=5, n_tasks=100):
    """Prepare meta-learning tasks from the Omniglot dataset"""
    # Process dataset into NumPy arrays
    images = []
    labels = []

    for img, label in dataset:
        # TFDS Omniglot images are 105x105 RGB; convert to grayscale and resize to 28x28
        img = tf.image.rgb_to_grayscale(img)
        images.append(tf.image.resize(img, (28, 28)).numpy())
        labels.append(int(label))

    images = np.stack(images).reshape(-1, 28*28).astype('float32') / 255.0  # Normalize
    labels = np.array(labels)

    # Group by class (each character is one class)
    class_examples = [images[labels == label] for label in np.unique(labels)]
    class_sizes = [len(examples) for examples in class_examples]

    # Only classes with enough examples for support + query can be used
    valid_classes = [i for i, size in enumerate(class_sizes) if size >= k_shot + n_query]
    if len(valid_classes) < n_way:
        raise ValueError(f"Only {len(valid_classes)} classes have at least "
                         f"{k_shot + n_query} examples; need {n_way}.")

    # Generate tasks
    tasks = []
    for _ in range(n_tasks):
        selected_classes = np.random.choice(valid_classes, n_way, replace=False)

        support_x = []
        support_y = []
        query_x = []
        query_y = []

        for i, cls_idx in enumerate(selected_classes):
            examples = class_examples[cls_idx]
            indices = np.random.choice(len(examples), k_shot + n_query, replace=False)

            support_indices = indices[:k_shot]
            query_indices = indices[k_shot:k_shot + n_query]

            support_x.append(examples[support_indices])
            support_y.append(np.full(k_shot, i, dtype=np.int64))
            query_x.append(examples[query_indices])
            query_y.append(np.full(n_query, i, dtype=np.int64))

        support_x = np.vstack(support_x)
        support_y = np.concatenate(support_y)
        query_x = np.vstack(query_x)
        query_y = np.concatenate(query_y)

        tasks.append((support_x, support_y, query_x, query_y))

    return tasks

Let's create a CNN model specifically for Omniglot:

def create_omniglot_model():
    model = models.Sequential([
        layers.Input(shape=(28*28,)),
        layers.Reshape((28, 28, 1)),
        layers.Conv2D(32, 3, activation='relu', padding='same'),
        layers.MaxPooling2D(),  # 28x28 -> 14x14
        layers.Conv2D(64, 3, activation='relu', padding='same'),
        layers.MaxPooling2D(),  # 14x14 -> 7x7
        layers.Flatten(),  # Flatten feature maps before the dense layers
        layers.Dense(64, activation='relu'),
        layers.Dense(5, activation='softmax')  # 5-way classification
    ])
    return model
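
A Dense layer placed after convolutions needs a flat vector, and the size of that vector follows from the spatial size of the last feature map. The hypothetical helper below mirrors the output-size rules Keras documents for Conv2D and MaxPooling2D: ceil(n / stride) for padding='same' and floor((n − kernel) / stride) + 1 for padding='valid'. The example traces a 28×28 input through two stride-1 'same' 3×3 convolutions, each followed by default 2×2 max pooling, assuming 64 channels at the end.

```python
import math


def output_size(n, kernel, stride=1, padding="valid"):
    """Spatial output size of a conv/pool layer along one dimension."""
    if padding == "same":
        return math.ceil(n / stride)
    if padding == "valid":
        return (n - kernel) // stride + 1
    raise ValueError(f"unknown padding: {padding!r}")


size = 28
size = output_size(size, kernel=3, padding="same")  # 28 after 3x3 'same' conv
size = output_size(size, kernel=2, stride=2)        # 14 after 2x2 pooling
size = output_size(size, kernel=3, padding="same")  # 14 after 3x3 'same' conv
size = output_size(size, kernel=2, stride=2)        # 7 after 2x2 pooling
print(size, size * size * 64)  # 7 3136
```

Pooling odd sizes rounds down: a third pooling layer would turn 7×7 into 3×3.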

# Create MAML instance with omniglot model
omniglot_maml = MAML(create_omniglot_model, inner_lr=0.1, meta_lr=0.001)

# Prepare tasks and train (code would be similar to MNIST example above)

Prototypical Networks: Another Meta-Learning Approach

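MAML is just one approach to meta-learning. Let's implement another popular technique, Prototypical Networks, which embeds examples with a neural network and classifies each query by its distance to the mean embedding (the prototype) of each class: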

class ProtoNet:
    def __init__(self, input_shape=(28*28,)):
        # The Reshape below assumes flattened 28x28 grayscale inputs
        self.encoder = models.Sequential([
            layers.Input(shape=input_shape),
            layers.Reshape((28, 28, 1)),
            layers.Conv2D(64, 3, activation='relu', padding='same'),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, activation='relu', padding='same'),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, activation='relu', padding='same'),
            layers.MaxPooling2D(),
            layers.Flatten(),
            layers.Dense(64)  # Embedding dimension
        ])

        self.optimizer = tf.keras.optimizers.Adam(0.001)

    def compute_prototypes(self, support_x, support_y, n_way):
        """Compute class prototypes (mean embeddings) from the support set"""
        embeddings = self.encoder(support_x, training=True)
        prototypes = []

        for i in range(n_way):
            class_mask = tf.cast(tf.equal(support_y, i), tf.float32)
            class_mask = tf.reshape(class_mask, (-1, 1))
            class_embeddings = embeddings * class_mask
            class_sum = tf.reduce_sum(class_embeddings, axis=0)
            class_count = tf.reduce_sum(class_mask)
            prototype = class_sum / class_count
            prototypes.append(prototype)

        return tf.stack(prototypes)

    def train_step(self, support_x, support_y, query_x, query_y, n_way):
        """Single training step for prototypical network"""
        with tf.GradientTape() as tape:
            # Compute prototypes
            prototypes = self.compute_prototypes(support_x, support_y, n_way)

            # Get query embeddings
            query_embeddings = self.encoder(query_x, training=True)

            # Calculate squared distances to prototypes: [queries, classes]
            expanded_query = tf.expand_dims(query_embeddings, axis=1)  # [queries, 1, dim]
            expanded_protos = tf.expand_dims(prototypes, axis=0)  # [1, classes, dim]
            distances = tf.reduce_sum(tf.square(expanded_query - expanded_protos), axis=2)

            # Negative distances act as logits: closer prototypes get higher scores
            logits = -distances
            loss = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(query_y, logits, from_logits=True)
            )

        gradients = tape.gradient(loss, self.encoder.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))

        return loss
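
Stripped of the encoder, the prototypical-network classification rule is short: average the support embeddings of each class to get a prototype, then score each query by its negative squared Euclidean distance to every prototype and apply a softmax. The NumPy sketch below uses hand-picked 2-D points as a stand-in for encoder outputs so the arithmetic can be checked by hand; prototypes and proto_log_probs are hypothetical helper names.

```python
import numpy as np


def prototypes(support_emb, support_y, n_way):
    """Mean embedding of each class in the support set."""
    return np.stack([support_emb[support_y == c].mean(axis=0) for c in range(n_way)])


def proto_log_probs(query_emb, protos):
    """Log-softmax over negative squared Euclidean distances."""
    # [n_query, n_way] matrix of squared distances via broadcasting
    dists = ((query_emb[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    logits = -dists
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    return logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))


# Two classes, two support points each, in a 2-D embedding space
support_emb = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]])
support_y = np.array([0, 0, 1, 1])
protos = prototypes(support_emb, support_y, n_way=2)  # prototypes at (0, 1) and (4, 1)

query_emb = np.array([[1.0, 1.0], [3.5, 0.5]])
query_y = np.array([0, 1])
log_probs = proto_log_probs(query_emb, protos)
print(log_probs.argmax(axis=1))  # [0 1]

# Cross-entropy loss on the query set
loss = -log_probs[np.arange(len(query_y)), query_y].mean()
```

The first query lies at squared distance 1 from prototype 0 and 9 from prototype 1, so it is assigned to class 0; this mirrors the logits = -distances step in train_step.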

Summary

In this tutorial, we've explored meta-learning in TensorFlow, specifically:

  1. Core concepts of meta-learning including few-shot learning and task adaptation
  2. MAML (Model-Agnostic Meta-Learning) implementation for quick adaptation to new tasks
  3. Data preparation techniques for meta-learning with the MNIST and Omniglot datasets
  4. Prototypical Networks as an alternative meta-learning approach
  5. Real-world applications for image classification with limited data

Meta-learning is a powerful paradigm that enables models to learn efficiently from small amounts of data, making it valuable for applications where collecting large datasets is impractical or impossible.

Additional Resources and Exercises



  1. Few-Shot Classification: Extend the MAML implementation to work on your own dataset for few-shot classification.

  2. Hyperparameter Tuning: Experiment with different inner loop and meta learning rates to see how they affect adaptation performance.

  3. Meta-Reinforcement Learning: Adapt the MAML algorithm to work with reinforcement learning tasks.

  4. Domain Adaptation: Use meta-learning to build a model that can quickly adapt to different domains (e.g., different image styles).

  5. Memory-Augmented Meta-Learning: Implement a memory component to store information about previous tasks and improve adaptation.

  6. Visualizations: Create visualizations of how the feature embeddings evolve during meta-training.

Meta-learning is an active area of research with many exciting applications. As you develop your skills, you'll discover new ways to leverage these techniques for solving complex problems with limited data.

