
TensorFlow Self-Supervised Learning

Introduction

Self-supervised learning is one of the most active areas of modern deep learning. Unlike supervised learning, which relies on labeled data, self-supervised learning lets models learn useful representations from unlabeled data by deriving the "supervision" from the data itself. This makes it especially valuable when labeled data is scarce or expensive to obtain.

In this tutorial, we'll explore how to implement self-supervised learning techniques using TensorFlow. We'll cover both the theoretical foundations and practical implementations, with a focus on making these advanced concepts accessible to beginners.

Why Self-Supervised Learning?

Before diving into implementation, let's understand why self-supervised learning is so valuable:

  1. Reduces dependency on labeled data: Creating large labeled datasets is expensive and time-consuming
  2. Often improves generalization: Pretrained features tend to transfer well to downstream tasks
  3. Works with abundant unlabeled data: Takes advantage of the vast amounts of unlabeled data available in the world

Core Concepts

Self-supervised learning works by creating artificial supervised tasks from unlabeled data. These tasks (often called "pretext tasks") force the model to learn useful representations. Later, these representations can be fine-tuned for specific downstream tasks with minimal labeled data.

Common approaches include:

  1. Contrastive learning: Learn representations by comparing similar vs. dissimilar examples
  2. Masked prediction: Predict missing or masked parts of the input data
  3. Generative modeling: Generate or reconstruct data from partial observations
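
All three approaches share one trick: they turn a single unlabeled sample into an (input, target) pair. The NumPy sketch below shows this for masked prediction. The helper `make_masked_pretext_pair` is hypothetical and works on a plain vector rather than on images or text. It hides a random subset of positions, and the hidden values become the training targets.

```python
import numpy as np

def make_masked_pretext_pair(x, mask_rate=0.25, mask_value=0.0, seed=0):
    """Turn one unlabeled sample into an (input, target) training pair.

    A random subset of positions is hidden; a model would be trained to
    predict the original values at exactly those positions.
    """
    rng = np.random.default_rng(seed)
    mask = rng.random(x.shape) < mask_rate       # True = hidden position
    corrupted = np.where(mask, mask_value, x)    # what the model sees
    targets = x[mask]                            # what the model must recover
    return corrupted, targets, mask

# An unlabeled "sample": no human annotation is involved anywhere
x = np.arange(1.0, 9.0)
corrupted, targets, mask = make_masked_pretext_pair(x)
print(corrupted)
print(targets)
```

The supervision signal (`targets`) comes entirely from the data itself, which is what makes the task "self-supervised".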

Implementing a Simple Self-Supervised Model in TensorFlow

Let's start with a basic contrastive learning example. Without using any labels, we'll train an encoder to map two augmented views of the same image to nearby embeddings and views of different images to distant ones.

1. Setting Up the Environment

```python
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers, models
```

2. Data Preparation with Augmentations

In contrastive learning, the quality of the learned representations depends heavily on the augmentations used to create the two views of each image:

```python
def get_augmented_pairs(x):
    # Create two versions of the same image with different augmentations
    image = tf.cast(x['image'], tf.float32) / 255.0

    # First augmented version
    aug1 = tf.image.random_crop(image, (24, 24, 3))
    aug1 = tf.image.random_flip_left_right(aug1)
    aug1 = tf.image.random_brightness(aug1, 0.2)

    # Second augmented version (slightly different transformations)
    aug2 = tf.image.random_crop(image, (24, 24, 3))
    aug2 = tf.image.random_flip_left_right(aug2)
    aug2 = tf.image.random_saturation(aug2, 0.8, 1.2)

    # Keep pixel values in [0, 1] after the color jitter
    return tf.clip_by_value(aug1, 0.0, 1.0), tf.clip_by_value(aug2, 0.0, 1.0)

# Load CIFAR-10 dataset (without using labels)
dataset = tfds.load('cifar10', as_supervised=False, split='train')
dataset = dataset.map(lambda x: {'image': x['image']})

# Create pairs dataset: augment each (32, 32, 3) image *before* batching,
# because random_crop with a (24, 24, 3) size expects a single image
pairs_dataset = (
    dataset.shuffle(10000)
    .map(get_augmented_pairs, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(256, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
```
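
To see why augmentation has to happen per image, before `.batch()`, here is a NumPy sketch of the crop-and-flip part of the pipeline applied to one unbatched image. The helper `random_crop_and_flip` is hypothetical: it only mirrors what `tf.image.random_crop` plus `tf.image.random_flip_left_right` do, and it leaves out the brightness and saturation jitter. On a batched tensor of shape (256, 32, 32, 3), a (24, 24, 3) crop size would not even match the rank.

```python
import numpy as np

rng = np.random.default_rng(42)

def random_crop_and_flip(image, size=24):
    """One random view: crop a size x size window, then flip left-right with p=0.5."""
    h, w, _ = image.shape
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    view = image[top:top + size, left:left + size, :]
    if rng.random() < 0.5:
        view = view[:, ::-1, :]   # horizontal flip
    return view

# A single, unbatched CIFAR-10-shaped image with values in [0, 1)
image = rng.random((32, 32, 3))
view1, view2 = random_crop_and_flip(image), random_crop_and_flip(image)
print(view1.shape, view2.shape)  # (24, 24, 3) (24, 24, 3)
```

Each call draws a fresh crop position and flip decision, so the two views of the same image generally differ. Those differences are exactly what the encoder must learn to ignore.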

3. Building an Encoder Network

```python
def create_encoder():
    inputs = layers.Input((24, 24, 3))

    # Simple convolutional encoder
    x = layers.Conv2D(32, 3, activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D()(x)
    x = layers.Conv2D(64, 3, activation='relu', padding='same')(x)
    x = layers.MaxPooling2D()(x)
    x = layers.Conv2D(128, 3, activation='relu', padding='same')(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128)(x)  # Embedding dimension = 128

    # L2-normalize embeddings (a Keras layer, so the model stays serializable)
    outputs = layers.UnitNormalization(axis=-1)(x)

    return models.Model(inputs, outputs, name='encoder')

encoder = create_encoder()
```
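
To check the architecture by hand, the short calculation below traces the spatial size and counts the weights layer by layer. The helper names are illustrative. Pooling and normalization layers have no weights, so the total should match the one reported by `encoder.summary()`.

```python
def conv2d_params(kernel, in_ch, out_ch):
    # weights (kernel x kernel x in_ch x out_ch) plus one bias per output channel
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(in_units, out_units):
    return in_units * out_units + out_units

# Spatial size: 'same' padding keeps it, each 2x2 max-pool halves it
side = 24
side //= 2   # after the first MaxPooling2D  -> 12
side //= 2   # after the second MaxPooling2D -> 6 (then global average pooling -> 128 values)

params = {
    'conv32': conv2d_params(3, 3, 32),      # 896
    'conv64': conv2d_params(3, 32, 64),     # 18,496
    'conv128': conv2d_params(3, 64, 128),   # 73,856
    'dense128': dense_params(128, 128),     # 16,512
}
total = sum(params.values())
print(side, total)  # 6 109760
```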

4. Implementing the Contrastive Loss Function

The NT-Xent (Normalized Temperature-scaled Cross Entropy) loss is commonly used in contrastive learning:
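
Written out, consider a batch of $N$ images, giving $2N$ projections $z_1, \dots, z_{2N}$, where $z_i$ and $z_{p(i)}$ are the two views of the same image, and a temperature $\tau$. For this batch, the loss that the function below computes is:

```latex
\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert},
\qquad
p(i) = \begin{cases} i + N, & i \le N \\ i - N, & i > N \end{cases}

\ell_i = -\log \frac{\exp\big(\mathrm{sim}(z_i, z_{p(i)}) / \tau\big)}
                   {\sum_{k=1}^{2N} \mathbb{1}_{[k \ne i]} \exp\big(\mathrm{sim}(z_i, z_k) / \tau\big)},
\qquad
\mathcal{L} = \frac{1}{2N} \sum_{i=1}^{2N} \ell_i
```

In the code, `positives` encodes $p(i)$ (0-indexed), `self_mask` implements the indicator $\mathbb{1}_{[k \ne i]}$, and `temperature` is $\tau$. Because the embeddings are already L2-normalized, a plain matrix product gives the cosine similarities.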

```python
def contrastive_loss(projections_1, projections_2, temperature=0.5):
    # projections_1[i] and projections_2[i] are L2-normalized embeddings of two views of image i
    batch_size = tf.shape(projections_1)[0]

    # Concatenate the two sets of projections: shape (2N, dim)
    all_projections = tf.concat([projections_1, projections_2], axis=0)

    # Cosine similarity between all pairs (dot product of unit vectors): shape (2N, 2N)
    similarity_matrix = tf.matmul(all_projections, all_projections, transpose_b=True)

    # Index of each row's positive: view i in the first half pairs with i + N, and vice versa
    positives = tf.concat([
        tf.range(batch_size, 2 * batch_size),
        tf.range(batch_size)
    ], axis=0)

    # One-hot mask selecting the positive example in each row
    mask = tf.one_hot(positives, 2 * batch_size)

    # Apply temperature scaling
    similarity_matrix /= temperature

    # Exponentiate the scaled similarities
    similarity_matrix_exp = tf.exp(similarity_matrix)

    # Exclude self-similarity (the diagonal) from the denominator
    self_mask = tf.eye(2 * batch_size)
    similarity_matrix_exp = similarity_matrix_exp * (1 - self_mask)

    # Denominator: sum of exp similarities to all other examples, shape (2N,)
    denominator = tf.reduce_sum(similarity_matrix_exp, axis=1)

    # Scaled similarity of each row's positive pair, shape (2N,)
    positive_similarity = tf.reduce_sum(similarity_matrix * mask, axis=1)

    # Final loss: mean over all 2N rows of -log(exp(positive) / denominator)
    loss = -tf.reduce_mean(positive_similarity - tf.math.log(denominator))

    return loss
```
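
An independent NumPy reference is a convenient way to sanity-check the TensorFlow loss. The sketch below is a separate implementation, and its name `nt_xent_numpy` is hypothetical. It masks the diagonal with −∞ and uses the log-sum-exp trick for stability. On the same L2-normalized inputs, it should agree with `contrastive_loss` up to floating-point error.

```python
import numpy as np

def nt_xent_numpy(z1, z2, temperature=0.5):
    """Reference NT-Xent loss; z1[i] and z2[i] are two views of image i."""
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)     # unit vectors -> dot = cosine
    n2 = z.shape[0]
    n = n2 // 2
    logits = (z @ z.T) / temperature
    np.fill_diagonal(logits, -np.inf)                      # exclude self-similarity
    positives = np.concatenate([np.arange(n, n2), np.arange(n)])
    # log of the denominator via log-sum-exp, then subtract the positive logit
    row_max = np.max(logits, axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.sum(np.exp(logits - row_max), axis=1))
    return float(np.mean(log_denom - logits[np.arange(n2), positives]))

# Two images whose views are either correctly paired or deliberately swapped
e1, e2 = np.eye(2)
aligned = nt_xent_numpy(np.stack([e1, e2]), np.stack([e1, e2]))
swapped = nt_xent_numpy(np.stack([e1, e2]), np.stack([e2, e1]))
print(aligned, swapped)   # aligned < swapped
```

With orthogonal unit vectors and τ = 0.5, each scaled similarity is either 2 or 0. Working the formula by hand gives log(2 + e²) − 2 for the aligned case and log(2 + e²) for the swapped case.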

5. Training Loop

```python
# Create encoder model
encoder = create_encoder()

# Adam optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Training step, compiled into a graph with tf.function
@tf.function
def train_step(images_1, images_2):
    with tf.GradientTape() as tape:
        # Get embeddings for both augmented views
        embeddings_1 = encoder(images_1, training=True)
        embeddings_2 = encoder(images_2, training=True)

        # Calculate loss
        loss = contrastive_loss(embeddings_1, embeddings_2)

    # Calculate gradients and update weights
    gradients = tape.gradient(loss, encoder.trainable_variables)
    optimizer.apply_gradients(zip(gradients, encoder.trainable_variables))

    return loss

# Training loop
epochs = 10
for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0

    for aug1, aug2 in pairs_dataset:
        batch_loss = train_step(aug1, aug2)
        total_loss += float(batch_loss)
        num_batches += 1

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

# Save the trained encoder (native Keras format)
encoder.save('self_supervised_encoder.keras')
```
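
A useful reference point for the printed loss is its chance level. If every one of the 2N − 1 candidate similarities in a row is equal, for example because the encoder has collapsed to a constant output, the softmax assigns the positive a probability of 1/(2N − 1), so the loss is log(2N − 1) regardless of the temperature. The sketch below checks this. The helper name is illustrative. For this tutorial's batch size of 256, a loss that stays near this value suggests the encoder is not separating positives from negatives.

```python
import math
import numpy as np

def uniform_logits_loss(batch_size):
    """NT-Xent for one row when all 2N - 1 candidate logits are equal."""
    n_candidates = 2 * batch_size - 1          # every other projection in the 2N batch
    logits = np.zeros(n_candidates)            # all (scaled) similarities identical
    log_softmax_positive = logits[0] - np.log(np.sum(np.exp(logits)))
    return float(-log_softmax_positive)

batch_size = 256
print(round(uniform_logits_loss(batch_size), 4), round(math.log(2 * batch_size - 1), 4))  # 6.2364 6.2364
```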

6. Visualizing Learned Representations

Let's visualize the embeddings after training to see if our model has learned meaningful representations:

```python
from sklearn.manifold import TSNE

# Fetch one batch of 1,000 test images; labels are used only to color the plot
test_dataset = tfds.load('cifar10', as_supervised=True, split='test')
images, labels = next(iter(test_dataset.batch(1000)))

# Resize to 24x24 to match our model and scale pixels to [0, 1]
images = tf.image.resize(tf.cast(images, tf.float32), (24, 24)) / 255.0

# Generate embeddings and convert to NumPy
embeddings_np = encoder(images, training=False).numpy()
labels_np = labels.numpy()

# Apply t-SNE to project the 128-d embeddings to 2D
tsne = TSNE(n_components=2, random_state=42)
embeddings_2d = tsne.fit_transform(embeddings_np)

# Plot with colors based on the class labels
plt.figure(figsize=(10, 8))
scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=labels_np, cmap='tab10')
plt.colorbar(scatter, ticks=range(10))
plt.title('t-SNE Visualization of Self-Supervised Embeddings')
plt.savefig('embeddings_visualization.png')
plt.show()
```
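
A t-SNE plot is qualitative. A common numeric complement is k-nearest-neighbor accuracy on frozen embeddings. In this approach, each test embedding takes the majority label of its k most similar reference embeddings. The sketch below is a hypothetical NumPy helper, checked here on synthetic clusters. To apply it to this tutorial, you would embed a labeled reference slice of the training split plus the test images and pass both sets in.

```python
import numpy as np

def knn_accuracy(train_emb, train_labels, test_emb, test_labels, k=5):
    """Cosine-similarity k-NN classifier on frozen embeddings (majority vote)."""
    train = train_emb / np.linalg.norm(train_emb, axis=1, keepdims=True)
    test = test_emb / np.linalg.norm(test_emb, axis=1, keepdims=True)
    sims = test @ train.T                              # (n_test, n_train)
    nearest = np.argsort(-sims, axis=1)[:, :k]         # indices of the k most similar
    votes = train_labels[nearest]                      # (n_test, k) neighbor labels
    preds = np.array([np.bincount(v).argmax() for v in votes])
    return float(np.mean(preds == test_labels))

# Synthetic check: two well-separated clusters of 8-dimensional "embeddings"
rng = np.random.default_rng(0)
centers = np.eye(8)[:2] * 10.0
train_labels = np.repeat([0, 1], 50)
test_labels = np.repeat([0, 1], 20)
train_emb = centers[train_labels] + rng.normal(size=(100, 8))
test_emb = centers[test_labels] + rng.normal(size=(40, 8))
print(knn_accuracy(train_emb, train_labels, test_emb, test_labels))
```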

Real-World Application: Transfer Learning with Self-Supervised Models

One of the main benefits of self-supervised learning is transfer learning to downstream tasks. Let's see how to use our pre-trained encoder for image classification with limited labeled data:

```python
def create_classifier(encoder, num_classes=10):
    # Freeze the pre-trained encoder so only the new head is trained
    encoder.trainable = False

    # Build a classifier on top of the encoder
    inputs = layers.Input((24, 24, 3))
    x = encoder(inputs, training=False)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return models.Model(inputs, outputs)

# Create classifier using our pre-trained encoder
classifier = create_classifier(encoder)

def preprocess(image, label):
    # Resize to the encoder's 24x24 input and scale pixels to [0, 1]
    image = tf.image.resize(tf.cast(image, tf.float32), (24, 24)) / 255.0
    return image, label

# Prepare a small labeled dataset (simulating a limited-data scenario: 20% of the training set)
labeled_dataset = tfds.load('cifar10', as_supervised=True, split='train[:20%]')
labeled_dataset = labeled_dataset.map(preprocess).shuffle(10000).batch(64)

test_dataset = tfds.load('cifar10', as_supervised=True, split='test')
test_dataset = test_dataset.map(preprocess).batch(64)

# Compile the classifier
classifier.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the classifier
history = classifier.fit(
    labeled_dataset,
    epochs=5,
    validation_data=test_dataset
)

# Evaluate the classifier
test_loss, test_acc = classifier.evaluate(test_dataset)
print(f"Test accuracy: {test_acc:.4f}")
```
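
The `train[:20%]` split keeps 10,000 of CIFAR-10's 50,000 training images. When comparing label budgets (as in the exercises below), it helps to tabulate the split strings and how many optimizer steps each epoch will take at batch size 64. The helper `labeled_subset_plan` is hypothetical and only does the arithmetic. The example counts are exact for these percentages of 50,000.

```python
import math

CIFAR10_TRAIN_SIZE = 50_000   # size of CIFAR-10's training split

def labeled_subset_plan(percentages, batch_size=64):
    """Map each percentage to (tfds split string, #examples, steps per epoch)."""
    plan = {}
    for pct in percentages:
        n = CIFAR10_TRAIN_SIZE * pct // 100
        plan[pct] = (f'train[:{pct}%]', n, math.ceil(n / batch_size))
    return plan

for pct, (split, n, steps) in labeled_subset_plan([1, 5, 10, 20]).items():
    print(f'{split:>12}  {n:>6} images  {steps:>4} steps/epoch')
```

The last batch of each epoch is partial (for example, 10,000 / 64 = 156.25, so 157 steps), because the pipeline does not drop the remainder.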

Advanced Technique: SimCLR Implementation

Let's implement a simplified version of SimCLR (Simple Framework for Contrastive Learning of Visual Representations), one of the most popular self-supervised learning methods:

```python
class SimCLR(tf.keras.Model):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        self.encoder = create_encoder()
        # Projection head maps the 128-d representation to the space where the loss is applied
        self.projection_head = tf.keras.Sequential([
            layers.Input((128,)),
            layers.Dense(128, activation='relu'),
            layers.Dense(64)
        ])

    def call(self, inputs):
        # At inference time only the encoder representation is needed
        return self.encoder(inputs)

    def train_step(self, data):
        # Unpack the two augmented views
        view1, view2 = data

        with tf.GradientTape() as tape:
            # Forward pass through encoder and projection head
            z1 = self.projection_head(self.encoder(view1, training=True))
            z2 = self.projection_head(self.encoder(view2, training=True))

            # Normalize projections
            z1 = tf.math.l2_normalize(z1, axis=1)
            z2 = tf.math.l2_normalize(z2, axis=1)

            # Calculate contrastive loss
            loss = contrastive_loss(z1, z2, self.temperature)

        # Compute gradients and update weights
        trainable_vars = self.encoder.trainable_variables + self.projection_head.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {"loss": loss}

# Create and train SimCLR model (the optimizer is passed to the standard compile)
simclr_model = SimCLR()
simclr_model.compile(optimizer=tf.keras.optimizers.Adam(0.001))

# Train the model on the (view1, view2) pairs
history = simclr_model.fit(pairs_dataset, epochs=5)

# Save the encoder part for later use
simclr_model.encoder.save('simclr_encoder.keras')
```
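
The `temperature` argument of `SimCLR` divides the cosine similarities before the softmax inside NT-Xent. The NumPy sketch below uses made-up similarity values to show the effect. These values are illustrative only and are not measured from the model. A smaller τ sharpens the distribution, so the loss concentrates more on separating the positive from the closest negatives.

```python
import numpy as np

def softmax(logits):
    shifted = logits - np.max(logits)     # subtract the max for numerical stability
    e = np.exp(shifted)
    return e / e.sum()

# Cosine similarities of one anchor to [its positive, three negatives]
cosine = np.array([0.8, 0.5, 0.3, 0.1])
for tau in (1.0, 0.5, 0.1):
    probs = softmax(cosine / tau)
    print(f'tau={tau}: p(positive)={probs[0]:.3f}')
```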

Self-Supervised Learning for Text Data

Self-supervised learning isn't limited to images. Let's implement a simple self-supervised approach for text data using TensorFlow:

```python
VOCAB_SIZE = 10000    # Example vocabulary size
PAD_TOKEN_ID = 0      # Assumed padding token ID
MASK_TOKEN_ID = 1     # Assumed token ID reserved for [MASK]

# Text encoder: one contextual vector per token (no pooling, so every position can be predicted)
def create_text_encoder(vocab_size=VOCAB_SIZE, embedding_dim=128):
    inputs = layers.Input(shape=(None,), dtype='int32')
    x = layers.Embedding(vocab_size, embedding_dim)(inputs)
    x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)
    return models.Model(inputs, x, name='text_encoder')

def create_masked_language_model(encoder, vocab_size=VOCAB_SIZE):
    inputs = layers.Input(shape=(None,), dtype='int32')
    sequence_output = encoder(inputs)

    # Per-position prediction head: a distribution over the vocabulary for every token
    outputs = layers.Dense(vocab_size, activation='softmax')(sequence_output)

    return models.Model(inputs, outputs, name='masked_lm')

# Create a masked language modeling example from a batch of token IDs
def create_masks(inputs, mask_rate=0.15):
    inputs = tf.cast(inputs, tf.int32)

    # Randomly select ~mask_rate of the non-padding positions
    rand = tf.random.uniform(tf.shape(inputs))
    target_mask = (rand < mask_rate) & (inputs != PAD_TOKEN_ID)

    # Replace the selected tokens with [MASK]
    masked_inputs = tf.where(target_mask, MASK_TOKEN_ID, inputs)

    # Targets are the original tokens; the weights restrict the loss to masked positions
    sample_weights = tf.cast(target_mask, tf.float32)

    return masked_inputs, inputs, sample_weights

# Example usage (conceptual - needs a tf.data.Dataset of padded token-ID batches)
# text_encoder = create_text_encoder()
# masked_lm = create_masked_language_model(text_encoder)
# masked_lm.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
# masked_lm.fit(token_dataset.map(create_masks), epochs=5)  # token_dataset is hypothetical
```
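
In masked language modeling, the loss is averaged only over the positions that were masked. Unmasked tokens are visible in the input, so predicting them would be trivial. The NumPy sketch below makes that explicit, using the mask computed in `create_masks`. The function name `masked_token_loss` is hypothetical.

```python
import numpy as np

def masked_token_loss(probs, targets, target_mask):
    """Mean cross-entropy over masked positions only.

    probs: (batch, seq_len, vocab) predicted distributions
    targets: (batch, seq_len) original token IDs
    target_mask: (batch, seq_len) True where the input token was replaced by [MASK]
    """
    # Probability assigned to the correct token at every position
    p_correct = np.take_along_axis(probs, targets[..., None], axis=-1)[..., 0]
    nll = -np.log(p_correct)
    weights = target_mask.astype(np.float64)
    return float((nll * weights).sum() / max(weights.sum(), 1.0))

# Toy example: vocabulary of 4 tokens, one sequence of length 3
targets = np.array([[2, 3, 1]])
target_mask = np.array([[True, False, True]])
probs = np.full((1, 3, 4), 0.25)        # a model that predicts uniformly
print(masked_token_loss(probs, targets, target_mask))   # log(4) ≈ 1.3863
```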

Summary

In this tutorial, we've explored self-supervised learning using TensorFlow, focusing on:

  1. The basic concept of self-supervised learning and why it's important
  2. Contrastive learning as a practical approach for implementing self-supervised learning
  3. Building a simple contrastive learning model that learns from unlabeled images
  4. Transfer learning to leverage self-supervised representations for downstream tasks
  5. SimCLR implementation as a more advanced technique
  6. Text-based self-supervised learning concepts

Self-supervised learning represents a powerful paradigm that bridges the gap between unsupervised and supervised learning. By creating supervision signals from unlabeled data, we can train models that learn rich representations, reducing our dependence on expensive labeled datasets.

Additional Resources and Exercises

Exercises

  1. Modify the augmentation pipeline: Experiment with different image augmentations and observe their effect on representation quality.

  2. Evaluate transfer learning performance: Try fine-tuning your self-supervised model on different percentages of labeled data (1%, 5%, 10%) and compare with training from scratch.

  3. Implement MoCo: Try implementing another popular self-supervised learning framework like Momentum Contrast (MoCo).

  4. Text domain application: Implement a complete masked language modeling task using a real text dataset and evaluate its performance on a downstream task like sentiment analysis.

  5. Multi-modal self-supervised learning: Explore how to use self-supervised learning across multiple modalities (e.g., matching images with text descriptions).
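
For Exercise 3, the key difference between MoCo and the SimCLR setup above is a second "key" encoder. This encoder is not trained by gradients. Instead, it tracks the query encoder as an exponential moving average: θ_k ← m·θ_k + (1 − m)·θ_q. MoCo also keeps a queue of past keys as negatives, which this sketch omits. The helper `momentum_update` is a hypothetical starting point that works on lists of weight arrays. The value m = 0.999 is the default reported in the MoCo paper.

```python
import numpy as np

def momentum_update(key_weights, query_weights, m=0.999):
    """Exponential moving average: key <- m * key + (1 - m) * query, per weight array."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_weights, query_weights)]

query = [np.ones((2, 2)), np.ones(2)]
key = [np.zeros((2, 2)), np.zeros(2)]
key = momentum_update(key, query, m=0.9)
print(key[1])   # [0.1 0.1]
```

With Keras models, the same update could be applied after each training step via `key_encoder.set_weights(momentum_update(key_encoder.get_weights(), query_encoder.get_weights()))`, where `key_encoder` and `query_encoder` are hypothetical models with identical architectures.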

By mastering these self-supervised learning techniques, you'll be well-equipped to tackle deep learning problems in data-limited scenarios, which are common in many real-world applications.


