TensorFlow Meta-Learning
Introduction
Meta-learning, often described as "learning to learn," represents an advanced paradigm in machine learning where systems improve their learning abilities over time and across different tasks. Unlike traditional machine learning approaches that start from scratch for each new task, meta-learning leverages experience from previous learning tasks to rapidly adapt to new, similar tasks with minimal data.
In this tutorial, we'll explore how to implement meta-learning techniques using TensorFlow. This approach is particularly valuable for scenarios where you have limited data for new tasks or need models that can quickly adapt to changing environments.
Key Concepts in Meta-Learning
Before diving into code, let's understand some fundamental concepts:
- Few-shot Learning: Training models to recognize patterns from just a few examples
- Model-Agnostic Meta-Learning (MAML): A popular meta-learning algorithm that learns a parameter initialization from which a few gradient steps adapt the model to a new task
- Learning to Learn: The process where a model improves its learning algorithm through experience
- Support and Query Sets: Each task is split into a support set, the few labeled examples used to adapt the model, and a query set, the held-out examples used to evaluate the adapted model
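To make the support/query terminology concrete, here is a minimal sketch of sampling one N-way K-shot episode in plain Python. The function name `sample_episode` and the dictionary layout of the data are assumptions made for illustration; they are not part of TensorFlow.

```python
import random

def sample_episode(examples_by_class, n_way, k_shot, n_query, rng):
    """Sample one episode: (support, query) lists of (example, episode_label) pairs.

    examples_by_class maps a class id to a list of examples (hypothetical layout).
    """
    classes = rng.sample(sorted(examples_by_class), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        # Draw disjoint support and query examples from the same class
        picked = rng.sample(examples_by_class[cls], k_shot + n_query)
        support += [(x, episode_label) for x in picked[:k_shot]]
        query += [(x, episode_label) for x in picked[k_shot:]]
    return support, query

# Toy data: 10 classes with 20 string "examples" each
toy_data = {c: [f"class{c}_item{j}" for j in range(20)] for c in range(10)}
support, query = sample_episode(toy_data, n_way=5, k_shot=1, n_query=5, rng=random.Random(0))
print(len(support), len(query))  # 5 support examples, 25 query examples
```

Classes are relabeled 0 to n_way - 1 inside each episode, so the model has to rely on the support set rather than on memorized global class ids.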
Implementing a Simple Meta-Learning Framework
Let's start by setting up a basic meta-learning framework using TensorFlow:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, models
# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)
Creating a Meta-Model Architecture
We'll first define a simple neural network that will serve as our base model:
def create_model():
    model = models.Sequential([
        layers.Dense(64, activation='relu', input_shape=(28*28,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(10, activation='softmax')
    ])
    return model
Implementing MAML (Model-Agnostic Meta-Learning)
Now, let's implement a simplified version of MAML, one of the most popular meta-learning algorithms:
class MAML:
    def __init__(self, model_fn, inner_lr=0.01, meta_lr=0.001):
        self.model_fn = model_fn
        self.meta_model = model_fn()
        self.inner_lr = inner_lr
        self.meta_optimizer = tf.keras.optimizers.Adam(meta_lr)
        
    def adapt(self, support_x, support_y, num_inner_steps=1):
        """Adapt model parameters to a new task based on support examples"""
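        # Copy the meta-parameters into a fresh model. set_weights() copies values,
        # so updates to the copy do not change (or backpropagate into) the meta-model.
        adapted_model = self.model_fn()
        adapted_model.set_weights(self.meta_model.get_weights())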
        
        # Inner loop optimization
        for _ in range(num_inner_steps):
            with tf.GradientTape() as tape:
                predictions = adapted_model(support_x, training=True)
                loss = tf.keras.losses.sparse_categorical_crossentropy(support_y, predictions)
                loss = tf.reduce_mean(loss)
            
            # Compute gradients and update the model's parameters
            gradients = tape.gradient(loss, adapted_model.trainable_variables)
            for i, var in enumerate(adapted_model.trainable_variables):
                var.assign_sub(self.inner_lr * gradients[i])
                
        return adapted_model
    
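    def meta_train_step(self, tasks):
        """Perform one meta-training step on a batch of tasks.

        Uses the first-order MAML approximation: adapt() copies weights with
        set_weights() and updates them with assign_sub(), so no gradient path
        leads from the query loss back to the meta-model. The query-loss
        gradients at the adapted weights are therefore applied to the
        meta-parameters directly. Runs eagerly (no @tf.function) because
        adapt() creates a new Keras model on every call.
        """
        meta_loss = 0.0
        meta_gradients = [tf.zeros_like(var) for var in self.meta_model.trainable_variables]

        for support_x, support_y, query_x, query_y in tasks:
            # Adapt a copy of the meta-model to the support set
            adapted_model = self.adapt(support_x, support_y)

            # Evaluate the adapted copy on the query set
            with tf.GradientTape() as tape:
                query_predictions = adapted_model(query_x, training=True)
                query_loss = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(query_y, query_predictions)
                )

            # Gradients with respect to the adapted weights (first-order approximation)
            task_gradients = tape.gradient(query_loss, adapted_model.trainable_variables)
            meta_gradients = [
                acc + grad / len(tasks) for acc, grad in zip(meta_gradients, task_gradients)
            ]
            meta_loss += query_loss / len(tasks)

        # Apply the averaged gradients to the meta-parameters
        self.meta_optimizer.apply_gradients(zip(meta_gradients, self.meta_model.trainable_variables))

        return meta_loss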
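The difference between the exact MAML gradient and its first-order approximation can be seen on a toy problem. The sketch below assumes a single scalar parameter and task losses of the form 0.5 * (theta - c)^2, where c is a hypothetical per-task optimum; all function names are illustrative. With one inner step, the exact gradient differs from the first-order one only by the factor (1 - alpha) that comes from differentiating through the inner update.

```python
def inner_step(theta, c, alpha):
    """One gradient step on the task loss 0.5 * (theta - c) ** 2."""
    return theta - alpha * (theta - c)

def query_grad_first_order(theta, c, alpha):
    """Gradient of the query loss evaluated at the adapted parameter."""
    theta_adapted = inner_step(theta, c, alpha)
    return theta_adapted - c

def query_grad_full(theta, c, alpha):
    """Exact MAML gradient: chain rule through the inner step,
    where d(theta_adapted) / d(theta) = 1 - alpha."""
    return (1.0 - alpha) * query_grad_first_order(theta, c, alpha)

def meta_train(task_optima, theta=0.0, alpha=0.1, beta=0.5, steps=200, first_order=True):
    """Outer loop: average the per-task query gradients and update theta."""
    grad_fn = query_grad_first_order if first_order else query_grad_full
    for _ in range(steps):
        meta_grad = sum(grad_fn(theta, c, alpha) for c in task_optima) / len(task_optima)
        theta -= beta * meta_grad
    return theta

task_optima = [1.0, 2.0, 3.0]
theta_fo = meta_train(task_optima, first_order=True)
theta_full = meta_train(task_optima, first_order=False)
print(round(theta_fo, 4), round(theta_full, 4))  # both converge to 2.0, the mean of the task optima
```

On this toy problem both variants reach the same initialization; on real networks the first-order variant trades some accuracy of the meta-gradient for much lower memory and compute cost.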
Preparing Data for Meta-Learning
For meta-learning, we need to structure our data differently. Let's prepare the MNIST dataset as a collection of tasks:
def prepare_mnist_tasks(n_way=5, k_shot=1, n_query=5, n_tasks=100):
    """
    Prepare meta-learning tasks from MNIST dataset
    n_way: Number of classes per task
    k_shot: Number of support examples per class
    n_query: Number of query examples per class
    n_tasks: Total number of tasks to generate
    """
    # Load and preprocess MNIST
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape(-1, 28*28).astype('float32') / 255.0
    
    # Group examples by class
    class_examples = [x_train[y_train == i] for i in range(10)]
    
    tasks = []
    for _ in range(n_tasks):
        # Randomly select n_way classes
        selected_classes = np.random.choice(10, n_way, replace=False)
        
        support_x = []
        support_y = []
        query_x = []
        query_y = []
        
        for i, cls in enumerate(selected_classes):
            # Select k_shot + n_query examples from this class
            examples = class_examples[cls]
            indices = np.random.choice(len(examples), k_shot + n_query, replace=False)
            
            # Split into support and query sets
            support_indices = indices[:k_shot]
            query_indices = indices[k_shot:k_shot + n_query]
            
            support_x.append(examples[support_indices])
            # Relabel as integer class ids 0, 1, 2, ... within the task
            support_y.append(np.full(k_shot, i, dtype=np.int64))
            
            query_x.append(examples[query_indices])
            query_y.append(np.full(n_query, i, dtype=np.int64))
        
        # Combine and shuffle support set
        support_x = np.vstack(support_x)
        support_y = np.concatenate(support_y)
        perm = np.random.permutation(len(support_y))
        support_x = support_x[perm]
        support_y = support_y[perm]
        
        # Combine and shuffle query set
        query_x = np.vstack(query_x)
        query_y = np.concatenate(query_y)
        perm = np.random.permutation(len(query_y))
        query_x = query_x[perm]
        query_y = query_y[perm]
        
        tasks.append((support_x, support_y, query_x, query_y))
    
    return tasks
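Before training, it can help to check that each generated task really has the intended structure. The helper below is a small sketch; the name `check_task` is hypothetical, and it only inspects the label arrays, so it works on plain lists as well as NumPy arrays.

```python
from collections import Counter

def check_task(support_y, query_y, n_way, k_shot, n_query):
    """Return True if label counts match an n_way, k_shot, n_query episode."""
    support_counts = Counter(int(y) for y in support_y)
    query_counts = Counter(int(y) for y in query_y)
    expected = set(range(n_way))
    return (
        set(support_counts) == expected
        and set(query_counts) == expected
        and all(count == k_shot for count in support_counts.values())
        and all(count == n_query for count in query_counts.values())
    )

# Hand-written labels for a 3-way, 2-shot task with 1 query example per class
print(check_task([0, 1, 2, 0, 1, 2], [2, 0, 1], n_way=3, k_shot=2, n_query=1))  # True
print(check_task([0, 0, 0, 1, 1, 2], [2, 0, 1], n_way=3, k_shot=2, n_query=1))  # False
```

For a task tuple produced by `prepare_mnist_tasks`, the same check would be called on its `support_y` and `query_y` entries with the values of `n_way`, `k_shot`, and `n_query` used to generate it.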
Training and Evaluating a Meta-Learning Model
Now let's train our meta-learning framework:
# Prepare meta-training tasks
meta_train_tasks = prepare_mnist_tasks(n_way=5, k_shot=5, n_query=10, n_tasks=1000)
meta_test_tasks = prepare_mnist_tasks(n_way=5, k_shot=5, n_query=10, n_tasks=100)
# Initialize MAML
maml = MAML(create_model, inner_lr=0.05, meta_lr=0.001)
# Meta-training loop
batch_size = 4
n_epochs = 5
for epoch in range(n_epochs):
    # Shuffle tasks
    np.random.shuffle(meta_train_tasks)
    
    total_loss = 0
    n_batches = len(meta_train_tasks) // batch_size
    
    for i in range(n_batches):
        task_batch = meta_train_tasks[i * batch_size:(i + 1) * batch_size]
        # Convert the scalar tensor to a Python float for accumulation and logging
        loss = float(maml.meta_train_step(task_batch))
        total_loss += loss
        
        if i % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {i}/{n_batches}, Loss: {loss:.4f}")
    
    avg_loss = total_loss / n_batches
    print(f"Epoch {epoch+1} completed, Average Loss: {avg_loss:.4f}")
# Meta-testing
test_accuracies = []
for task in meta_test_tasks:
    support_x, support_y, query_x, query_y = task
    
    # Adapt model to support set (5 inner steps for better adaptation)
    adapted_model = maml.adapt(support_x, support_y, num_inner_steps=5)
    
    # Evaluate on query set
    query_predictions = adapted_model(query_x)
    query_predicted_classes = tf.argmax(query_predictions, axis=1)
    
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(query_predicted_classes, tf.cast(query_y, tf.int64)), tf.float32)
    )
    test_accuracies.append(float(accuracy))
mean_accuracy = np.mean(test_accuracies)
print(f"Meta-test accuracy: {mean_accuracy:.4f}")
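Few-shot results vary noticeably from task to task, so the mean accuracy is often reported together with a confidence interval. The sketch below assumes the per-task accuracies are available as plain Python floats (tensors can be converted with `float(...)`) and uses a normal approximation; the helper name `mean_with_ci95` and the sample values are illustrative only.

```python
import math
import statistics

def mean_with_ci95(accuracies):
    """Return (mean, half_width) of a normal-approximation 95% confidence interval."""
    mean = statistics.fmean(accuracies)
    if len(accuracies) < 2:
        return mean, float("nan")
    std_err = statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return mean, 1.96 * std_err

# Made-up per-task accuracies, only to show the call
mean, half_width = mean_with_ci95([0.80, 0.90, 0.70, 0.85, 0.75])
print(f"accuracy: {mean:.3f} +/- {half_width:.3f}")  # accuracy: 0.800 +/- 0.069
```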
Real-World Application: Few-Shot Image Classification
Let's implement a more practical example using meta-learning for few-shot image classification with the Omniglot dataset, which is commonly used for meta-learning benchmarks:
# First, let's download and prepare the Omniglot dataset
import tensorflow_datasets as tfds
# Load Omniglot dataset
omniglot_ds = tfds.load('omniglot', split='train+test', as_supervised=True)
def prepare_omniglot_tasks(dataset, n_way=5, k_shot=1, n_query=5, n_tasks=100):
    """Prepare meta-learning tasks from Omniglot dataset"""
    # Process dataset
    images = []
    labels = []
    
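    for img, label in dataset:
        # TFDS serves Omniglot images with 3 channels; average them into one
        # channel so each image flattens to exactly 28*28 features
        img = tf.reduce_mean(tf.cast(img, tf.float32), axis=-1, keepdims=True)
        images.append(tf.image.resize(img, (28, 28)))
        labels.append(label)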
    
    images = tf.stack(images) / 255.0  # Normalize
    images = tf.reshape(images, (-1, 28*28))
    labels = tf.stack(labels)
    
    # Group by class
    unique_labels = tf.unique(labels)[0]
    n_classes = len(unique_labels)
    
    class_examples = [images[labels == label] for label in unique_labels]
    class_sizes = [len(examples) for examples in class_examples]
    
    # Generate tasks
    tasks = []
    for _ in range(n_tasks):
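        # Select classes with enough examples; fail loudly instead of
        # silently returning fewer tasks than requested
        valid_classes = [i for i, size in enumerate(class_sizes) if size >= k_shot + n_query]
        if len(valid_classes) < n_way:
            raise ValueError("Not enough classes with k_shot + n_query examples")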
            
        selected_classes = np.random.choice(valid_classes, n_way, replace=False)
        
        support_x = []
        support_y = []
        query_x = []
        query_y = []
        
        for i, cls_idx in enumerate(selected_classes):
            examples = class_examples[cls_idx].numpy()
            indices = np.random.choice(len(examples), k_shot + n_query, replace=False)
            
            support_indices = indices[:k_shot]
            query_indices = indices[k_shot:k_shot + n_query]
            
            support_x.append(examples[support_indices])
            # Relabel as integer class ids 0, 1, 2, ... within the task
            support_y.append(np.full(k_shot, i, dtype=np.int64))
            
            query_x.append(examples[query_indices])
            query_y.append(np.full(n_query, i, dtype=np.int64))
        
        support_x = np.vstack(support_x)
        support_y = np.concatenate(support_y)
        query_x = np.vstack(query_x)
        query_y = np.concatenate(query_y)
        
        tasks.append((support_x, support_y, query_x, query_y))
    
    return tasks
Let's create a CNN model specifically for Omniglot:
def create_omniglot_model():
    model = models.Sequential([
        layers.Reshape((28, 28, 1), input_shape=(28*28,)),
        layers.Conv2D(32, 3, activation='relu', padding='same'),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation='relu', padding='same'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(5, activation='softmax')  # 5-way classification
    ])
    return model
# Create MAML instance with omniglot model
omniglot_maml = MAML(create_omniglot_model, inner_lr=0.1, meta_lr=0.001)
# Prepare tasks and train (code would be similar to MNIST example above)
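Few-shot benchmarks usually evaluate on classes that were never seen during meta-training, whereas the task generator above samples from every class it is given. One way to get held-out classes is to split the class ids first and build separate task sets from each part. The sketch below is an assumption about how that split could look; `split_classes` is a hypothetical helper, and 1623 is the number of character classes in Omniglot.

```python
import random

def split_classes(class_ids, test_fraction=0.2, seed=0):
    """Split class ids into disjoint meta-train and meta-test lists."""
    ids = sorted(set(class_ids))
    random.Random(seed).shuffle(ids)
    n_test = max(1, int(len(ids) * test_fraction))
    return ids[n_test:], ids[:n_test]

train_classes, test_classes = split_classes(range(1623), test_fraction=0.2)
print(len(train_classes), len(test_classes))  # 1299 324
```

The dataset would then be filtered by these class ids before generating meta-training and meta-test tasks.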
Prototypical Networks: Another Meta-Learning Approach
MAML is just one approach to meta-learning. Let's implement another popular technique called Prototypical Networks:
class ProtoNet:
    def __init__(self, input_shape=(28*28,)):
        # input_shape is the flattened image size; the encoder reshapes it to 28x28x1
        self.encoder = models.Sequential([
            layers.Reshape((28, 28, 1), input_shape=input_shape),
            layers.Conv2D(64, 3, activation='relu', padding='same'),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, activation='relu', padding='same'),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, activation='relu', padding='same'),
            layers.MaxPooling2D(),
            layers.Flatten(),
            layers.Dense(64)  # Embedding dimension
        ])
        
        self.optimizer = tf.keras.optimizers.Adam(0.001)
    
    def compute_prototypes(self, support_x, support_y, n_way):
        """Compute class prototypes from support set"""
        embeddings = self.encoder(support_x, training=True)
        prototypes = []
        
        for i in range(n_way):
            class_mask = tf.cast(tf.equal(support_y, i), tf.float32)
            class_mask = tf.reshape(class_mask, (-1, 1))
            class_embeddings = embeddings * class_mask
            class_sum = tf.reduce_sum(class_embeddings, axis=0)
            class_count = tf.reduce_sum(class_mask)
            prototype = class_sum / class_count
            prototypes.append(prototype)
            
        return tf.stack(prototypes)
    
    @tf.function
    def train_step(self, support_x, support_y, query_x, query_y, n_way):
        """Single training step for prototypical network"""
        with tf.GradientTape() as tape:
            # Compute prototypes
            prototypes = self.compute_prototypes(support_x, support_y, n_way)
            
            # Get query embeddings
            query_embeddings = self.encoder(query_x, training=True)
            
            # Calculate squared distances to prototypes
            expanded_query = tf.expand_dims(query_embeddings, axis=1)  # [queries, 1, dim]
            expanded_protos = tf.expand_dims(prototypes, axis=0)       # [1, classes, dim]
            distances = tf.reduce_sum(tf.square(expanded_query - expanded_protos), axis=2)
            
            # Use negative distances as logits: closer prototypes get higher scores.
            # Passing logits directly (from_logits=True) is more numerically stable
            # than applying softmax first.
            logits = -distances
            loss = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(
                    query_y, logits, from_logits=True
                )
            )
            
        gradients = tape.gradient(loss, self.encoder.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))
        
        return loss
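The classification rule behind Prototypical Networks is simple enough to trace by hand. The sketch below works on small hand-written 2-D "embeddings" in plain Python rather than on encoder outputs; the function names and the sample points are illustrative assumptions.

```python
import math

def compute_prototypes(embeddings, labels, n_way):
    """Average the embeddings of each class to get one prototype per class."""
    dim = len(embeddings[0])
    prototypes = []
    for cls in range(n_way):
        members = [e for e, y in zip(embeddings, labels) if y == cls]
        prototypes.append([sum(e[d] for e in members) / len(members) for d in range(dim)])
    return prototypes

def class_probabilities(query, prototypes):
    """Softmax over negative squared Euclidean distances to each prototype."""
    logits = [-sum((q - p) ** 2 for q, p in zip(query, proto)) for proto in prototypes]
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(logit - max_logit) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]

support = [[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]]
support_labels = [0, 0, 1, 1]
prototypes = compute_prototypes(support, support_labels, n_way=2)
print(prototypes)  # [[0.0, 1.0], [5.0, 4.0]]

probs = class_probabilities([1.0, 1.0], prototypes)
print(probs.index(max(probs)))  # 0, the class of the nearest prototype
```

The query point is at squared distance 1 from the first prototype and 25 from the second, so almost all of the probability mass goes to class 0.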
Summary
In this tutorial, we've explored meta-learning in TensorFlow, specifically:
- Core concepts of meta-learning including few-shot learning and task adaptation
- MAML (Model-Agnostic Meta-Learning) implementation for quick adaptation to new tasks
- Data preparation techniques for meta-learning with the MNIST and Omniglot datasets
- Prototypical Networks as an alternative meta-learning approach
- Real-world applications for image classification with limited data
Meta-learning is a powerful paradigm that enables models to learn efficiently from small amounts of data, making it valuable for applications where collecting large datasets is impractical or impossible.
Additional Resources and Exercises
Resources
- TensorFlow Model Garden for more advanced implementations
- Meta-Learning: Learning to Learn Fast - comprehensive blog post
- Chelsea Finn's MAML paper - the original MAML research paper
Exercises
- Few-Shot Classification: Extend the MAML implementation to work on your own dataset for few-shot classification.
- Hyperparameter Tuning: Experiment with different inner loop and meta learning rates to see how they affect adaptation performance.
- Meta-Reinforcement Learning: Adapt the MAML algorithm to work with reinforcement learning tasks.
- Domain Adaptation: Use meta-learning to build a model that can quickly adapt to different domains (e.g., different image styles).
- Memory-Augmented Meta-Learning: Implement a memory component to store information about previous tasks and improve adaptation.
- Visualizations: Create visualizations of how the feature embeddings evolve during meta-training.
Meta-learning is an active area of research with many exciting applications. As you develop your skills, you'll discover new ways to leverage these techniques for solving complex problems with limited data.