PyTorch Self-Supervised Learning
Self-supervised learning has emerged as a powerful paradigm in deep learning, enabling models to learn meaningful representations from unlabeled data. In this tutorial, we'll explore how to implement several self-supervised learning techniques using PyTorch.
Introduction to Self-Supervised Learning
Self-supervised learning is a subset of unsupervised learning where the model generates its own supervisory signal from the data itself. Instead of relying on human-annotated labels, the model learns to solve "pretext tasks" where the labels are derived automatically from the input data.
The core idea is to:
- Create a pretext task that requires understanding of the data's structure
- Train a model to solve this task
- Use the trained model's feature representations for downstream tasks
This approach is particularly valuable when labeled data is scarce or expensive to obtain.
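For intuition, consider rotation prediction, one of the simplest pretext tasks: rotate each image by 0, 90, 180, or 270 degrees and train the network to predict which rotation was applied, so the label comes for free from the data. The helper below is a minimal illustrative sketch (the function name and shapes are our own, not taken from any particular library):
import torch
import torchvision.transforms.functional as TF

def make_rotation_batch(images):
    """Build a rotation-prediction pretext batch from images of shape [B, C, H, W]."""
    # Pick a random rotation index (0-3) per image; this index is the "free" label
    rotations = torch.randint(0, 4, (images.size(0),))
    rotated = torch.stack([
        TF.rotate(img, angle=90.0 * int(k)) for img, k in zip(images, rotations)
    ])
    return rotated, rotations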
Why Self-Supervised Learning?
- Reduces dependency on labeled data: Most deep learning applications require large amounts of labeled data, which can be expensive and time-consuming to collect.
- Learns general-purpose representations: Features learned through self-supervision often transfer well to multiple downstream tasks.
- Captures underlying data structure: Self-supervised models can identify patterns and relationships in data that might be missed in purely supervised approaches.
Common Self-Supervised Learning Techniques
Let's explore some popular self-supervised learning techniques with PyTorch implementations:
1. Simple Contrastive Learning
Contrastive learning trains models to bring similar samples closer in the embedding space while pushing dissimilar samples apart. One of the simplest implementations is to create different augmented views of the same image.
First, let's set up our imports:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
Now, let's define a simple contrastive loss function, the NT-Xent loss used by SimCLR:
class ContrastiveLoss(nn.Module):
def __init__(self, temperature=0.5):
super().__init__()
self.temperature = temperature
def forward(self, z_i, z_j):
"""
z_i, z_j: Batch of embeddings [batch_size, embedding_dim]
"""
batch_size = z_i.size(0)
# Normalize embeddings
z_i = F.normalize(z_i, dim=1)
z_j = F.normalize(z_j, dim=1)
# Concatenate representations
representations = torch.cat([z_i, z_j], dim=0)
# Calculate similarity matrix
similarity_matrix = F.cosine_similarity(
representations.unsqueeze(1),
representations.unsqueeze(0),
dim=2
)
        # Extract the positive-pair similarities from the off-diagonals at offset ±batch_size
sim_ij = torch.diag(similarity_matrix, batch_size)
sim_ji = torch.diag(similarity_matrix, -batch_size)
positives = torch.cat([sim_ij, sim_ji], dim=0)
        # Mask out self-similarities; everything else serves as negatives in the denominator
mask = ~torch.eye(2 * batch_size, device=z_i.device).bool()
negatives = similarity_matrix[mask].view(2 * batch_size, -1)
# Calculate NT-Xent loss
logits = torch.cat([positives.unsqueeze(1), negatives], dim=1) / self.temperature
labels = torch.zeros(2 * batch_size, device=z_i.device, dtype=torch.long)
return F.cross_entropy(logits, labels)
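As a quick, purely illustrative sanity check, the loss can be called on random embeddings to confirm the shapes line up:
# Illustrative sanity check with random embeddings
criterion = ContrastiveLoss(temperature=0.5)
z_i = torch.randn(8, 128)  # [batch_size, embedding_dim] for view 1
z_j = torch.randn(8, 128)  # matching embeddings for view 2
loss = criterion(z_i, z_j)
print(loss.item())  # a single scalar loss value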
Now, let's implement a simple encoder network:
class EncoderNetwork(nn.Module):
def __init__(self, base_encoder="resnet18", projection_dim=128):
super().__init__()
# Load base encoder (e.g., ResNet)
if base_encoder == "resnet18":
            self.encoder = torchvision.models.resnet18(weights=None)  # random init (pretrained=False on older torchvision)
feature_dim = self.encoder.fc.in_features
self.encoder.fc = nn.Identity() # Remove final FC layer
else:
raise NotImplementedError(f"Base encoder {base_encoder} not supported")
# Projection head (MLP)
self.projector = nn.Sequential(
nn.Linear(feature_dim, 512),
nn.ReLU(),
nn.Linear(512, projection_dim)
)
def forward(self, x):
h = self.encoder(x)
z = self.projector(h)
return z
Now, let's create a dataset that returns pairs of augmented views for each image:
class ContrastiveViewsDataset(Dataset):
def __init__(self, base_dataset, transform):
self.base_dataset = base_dataset
self.transform = transform
def __getitem__(self, idx):
img, label = self.base_dataset[idx]
# Apply two different augmentations
view1 = self.transform(img)
view2 = self.transform(img)
return view1, view2, label
def __len__(self):
return len(self.base_dataset)
Let's define our augmentation pipeline:
def get_transforms():
color_jitter = T.ColorJitter(0.4, 0.4, 0.4, 0.1)
train_transform = T.Compose([
        T.RandomResizedCrop(size=96),  # note: CIFAR-10 images are 32x32, so this upsamples them
T.RandomHorizontalFlip(),
T.RandomApply([color_jitter], p=0.8),
T.RandomGrayscale(p=0.2),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return train_transform
Now, let's put everything together in a training loop:
def train_contrastive(encoder, train_loader, optimizer, criterion, device, epochs=30):
encoder.train()
for epoch in range(epochs):
running_loss = 0.0
for batch_idx, (view1, view2, _) in enumerate(train_loader):
view1, view2 = view1.to(device), view2.to(device)
# Zero gradients
optimizer.zero_grad()
# Forward pass
z_i = encoder(view1)
z_j = encoder(view2)
loss = criterion(z_i, z_j)
# Backward pass and optimize
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % 100 == 99:
print(f'Epoch: {epoch+1}/{epochs}, Batch: {batch_idx+1}, Loss: {running_loss/100:.4f}')
running_loss = 0.0
print('Training complete.')
return encoder
Let's implement a full training example:
def main():
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
    # Load CIFAR-10 without a transform here: the contrastive transform
    # below expects PIL images and applies ToTensor itself
    cifar10 = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=True
    )
# Create contrastive views dataset
contrastive_dataset = ContrastiveViewsDataset(
cifar10,
transform=get_transforms()
)
# Create data loader
train_loader = DataLoader(
contrastive_dataset,
batch_size=128,
shuffle=True,
num_workers=2,
pin_memory=True,
drop_last=True
)
# Initialize model
encoder = EncoderNetwork().to(device)
# Set up optimizer
optimizer = torch.optim.Adam(encoder.parameters(), lr=3e-4, weight_decay=1e-6)
# Set up loss function
criterion = ContrastiveLoss(temperature=0.5)
# Train the model
encoder = train_contrastive(
encoder=encoder,
train_loader=train_loader,
optimizer=optimizer,
criterion=criterion,
device=device,
epochs=10
)
# Save the model
torch.save(encoder.state_dict(), 'contrastive_encoder.pth')
print("Model saved.")
if __name__ == "__main__":
main()
2. SimCLR: A Simple Framework for Contrastive Learning
SimCLR is one of the most influential contrastive learning methods. Let's implement a simplified version:
# We can reuse most of the code from above, but with specific SimCLR augmentations
def get_simclr_transforms():
color_jitter = T.ColorJitter(0.8, 0.8, 0.8, 0.2)
simclr_transform = T.Compose([
T.RandomResizedCrop(size=96),
T.RandomHorizontalFlip(),
T.RandomApply([color_jitter], p=0.8),
T.RandomGrayscale(p=0.2),
T.GaussianBlur(kernel_size=9), # SimCLR specifically uses blur as augmentation
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return simclr_transform
# Everything else remains very similar to our previous implementation
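To train with these augmentations, we only need to swap the transform when building the dataset; everything else reuses the classes defined earlier. Here is a sketch that assumes cifar10 and device from the previous example are still in scope:
# Sketch: SimCLR-style training reusing the earlier components
simclr_dataset = ContrastiveViewsDataset(cifar10, transform=get_simclr_transforms())
simclr_loader = DataLoader(simclr_dataset, batch_size=128, shuffle=True, drop_last=True)
simclr_encoder = EncoderNetwork().to(device)
simclr_optimizer = torch.optim.Adam(simclr_encoder.parameters(), lr=3e-4, weight_decay=1e-6)
simclr_encoder = train_contrastive(simclr_encoder, simclr_loader, simclr_optimizer,
                                   ContrastiveLoss(temperature=0.5), device, epochs=10)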
3. Autoencoding and Masked Image Modeling
Another approach to self-supervised learning is to train models to reconstruct images from partial observations:
class MaskedAutoencoder(nn.Module):
def __init__(self, mask_ratio=0.75):
super().__init__()
self.mask_ratio = mask_ratio
# Encoder (simplified)
self.encoder = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
)
# Decoder (simplified)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.ConvTranspose2d(32, 3, kernel_size=2, stride=2),
nn.Sigmoid()
)
def random_masking(self, x):
B, C, H, W = x.shape
        # Create a patch-level mask: 1 = keep (visible), 0 = masked out
        mask = torch.rand(B, 1, H//8, W//8, device=x.device)
        mask = (mask > self.mask_ratio).float()
mask = mask.repeat_interleave(8, dim=2).repeat_interleave(8, dim=3)
# Apply mask
masked_x = x * mask
return masked_x, mask
def forward(self, x):
# Randomly mask the input
masked_x, mask = self.random_masking(x)
# Encode masked input
features = self.encoder(masked_x)
# Decode to reconstruct the original image
reconstructed = self.decoder(features)
return reconstructed, mask
# Let's define a training function
def train_autoencoder(model, train_loader, optimizer, device, epochs=30):
model.train()
criterion = nn.MSELoss()
for epoch in range(epochs):
running_loss = 0.0
for batch_idx, (images, _) in enumerate(train_loader):
images = images.to(device)
# Zero gradients
optimizer.zero_grad()
# Forward pass
reconstructed, mask = model(images)
            # Calculate reconstruction loss only on the masked-out regions
            # (mask == 1 marks visible pixels, so 1 - mask selects the masked ones)
            loss = criterion(reconstructed * (1 - mask), images * (1 - mask))
# Backward pass and optimize
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % 100 == 99:
print(f'Epoch: {epoch+1}/{epochs}, Batch: {batch_idx+1}, Loss: {running_loss/100:.4f}')
running_loss = 0.0
print('Training complete.')
return model
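Putting it together, a minimal end-to-end run on CIFAR-10 might look like this (the hyperparameters are illustrative, not tuned):
def autoencoder_example():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Plain ToTensor is enough here: the model masks the input itself
    dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=T.ToTensor())
    loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)
    model = MaskedAutoencoder(mask_ratio=0.75).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model = train_autoencoder(model, loader, optimizer, device, epochs=10)
    torch.save(model.state_dict(), 'masked_autoencoder.pth')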
Evaluating Self-Supervised Models
After training a self-supervised model, we typically evaluate its representation quality by:
- Linear Evaluation Protocol: Freeze the encoder and train a linear classifier on top of it
- Fine-tuning: Adapt the entire network for a downstream task
- Feature Visualization: Examine the learned feature space using dimensionality reduction techniques (a t-SNE sketch follows the linear evaluation code below)
Let's implement a linear evaluation example. Note that it expects standard labeled loaders that yield (images, labels) batches, not the contrastive pair loader from earlier:
def linear_evaluation(encoder, train_loader, test_loader, device, num_classes=10):
# Freeze encoder weights
for param in encoder.parameters():
param.requires_grad = False
# Get the dimension of the encoder output
with torch.no_grad():
dummy_input = next(iter(train_loader))[0][:1].to(device)
feat_dim = encoder.encoder(dummy_input).shape[1]
# Create a linear classifier
classifier = nn.Linear(feat_dim, num_classes).to(device)
# Set up optimizer and criterion
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# Train the classifier
for epoch in range(20):
# Training
encoder.eval()
classifier.train()
running_loss = 0.0
correct = 0
total = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
with torch.no_grad():
features = encoder.encoder(images)
outputs = classifier(features)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
print(f"Epoch: {epoch+1}, Loss: {running_loss/len(train_loader):.4f}, "
f"Acc: {100.*correct/total:.2f}%")
# Evaluation
encoder.eval()
classifier.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
features = encoder.encoder(images)
outputs = classifier(features)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
accuracy = 100. * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
return accuracy
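Beyond linear evaluation, we can inspect the learned feature space directly with dimensionality reduction (the third option above). Below is a minimal sketch using scikit-learn's t-SNE on a standard labeled loader such as test_loader; it assumes scikit-learn is installed and reuses the matplotlib import from earlier:
from sklearn.manifold import TSNE

def visualize_features(encoder, data_loader, device, max_batches=20):
    """Project encoder features to 2D with t-SNE and color points by label."""
    encoder.eval()
    feats, labels = [], []
    with torch.no_grad():
        for i, (images, targets) in enumerate(data_loader):
            if i >= max_batches:
                break
            feats.append(encoder.encoder(images.to(device)).cpu())
            labels.append(targets)
    feats = torch.cat(feats).numpy()
    labels = torch.cat(labels).numpy()
    embedded = TSNE(n_components=2, init='pca').fit_transform(feats)
    plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap='tab10', s=5)
    plt.colorbar()
    plt.title('t-SNE of self-supervised features')
    plt.show()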
Real-World Application: Self-Supervised Learning for Medical Imaging
Self-supervised learning is particularly valuable in domains like medical imaging, where labeled data is scarce but unlabeled data is plentiful.
Let's create a simple example for lung CT scan analysis:
class CTScanContrastiveDataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
def __getitem__(self, idx):
img = plt.imread(self.image_paths[idx]) # This would be replaced with proper DICOM handling
if self.transform:
view1 = self.transform(img)
view2 = self.transform(img)
return view1, view2
else:
return img, img
def __len__(self):
return len(self.image_paths)
def ct_scan_transforms():
"""CT scan specific augmentations"""
transform = T.Compose([
T.ToTensor(),
T.RandomRotation(10),
T.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
T.RandomHorizontalFlip(),
# CT scan specific intensity adjustments
T.Lambda(lambda x: x + 0.05 * torch.randn_like(x)), # Add noise
        # Clamp back into the [0, 1] range after adding noise
        T.Lambda(lambda x: torch.clamp(x, 0, 1))
])
return transform
def medical_imaging_example():
# In a real application, you would load actual CT scan paths
# image_paths = [...list of CT scan file paths...]
# Set up your dataset
dataset = CTScanContrastiveDataset(
image_paths=["path1.dcm", "path2.dcm"], # Placeholder
transform=ct_scan_transforms()
)
# Train using our contrastive learning approach from earlier
# ...
# After training, you can use the encoder for:
# 1. Classification of lung conditions
# 2. Segmentation of abnormalities
# 3. Similarity search for related cases
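For instance, attaching a classification head to the pretrained encoder for downstream fine-tuning could look like this (a sketch; num_conditions and the checkpoint path are assumptions about your own setup):
def build_downstream_classifier(pretrained_path, num_conditions=3):
    """Load a contrastively pretrained encoder and attach a classification head."""
    encoder = EncoderNetwork()
    encoder.load_state_dict(torch.load(pretrained_path))
    feat_dim = 512  # output dimension of the ResNet-18 backbone
    # Fine-tune the backbone plus a new linear head on labeled CT scans
    return nn.Sequential(encoder.encoder, nn.Linear(feat_dim, num_conditions))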
Summary
In this tutorial, we've explored self-supervised learning in PyTorch, focusing on:
- Contrastive learning techniques like SimCLR that learn representations by comparing different views of the same image
- Masked autoencoding that reconstructs images from partial observations
- Evaluation methods to assess the quality of learned representations
- Practical applications in domains like medical imaging
Self-supervised learning offers a powerful way to leverage unlabeled data for learning useful representations. As deep learning continues to evolve, these techniques are becoming increasingly important for building more capable, data-efficient AI systems.
Additional Resources
- Self-Supervised Learning at PyTorch
- VISSL: A library for state-of-the-art self-supervised learning
- SimCLR Paper
- MAE Paper
Exercises
- Basic Implementation: Modify the contrastive learning example to work with a different dataset like MNIST or Fashion-MNIST.
- Advanced Techniques: Implement MoCo (Momentum Contrast) or BYOL (Bootstrap Your Own Latent) self-supervised learning methods in PyTorch.
- Transfer Learning: Use a pre-trained self-supervised model as initialization for a downstream task like image classification on a small dataset.
- Data Visualization: Implement t-SNE or UMAP visualization to explore the learned feature space of your self-supervised model.
- Domain Adaptation: Apply self-supervised learning to a domain adaptation problem, where you train on one dataset but need to perform well on a different but related dataset.