PyTorch VAEs
Introduction
Variational Autoencoders (VAEs) are a powerful class of generative models that allow us not only to compress data but also to generate new samples that resemble the training data. Unlike regular autoencoders, which focus solely on reconstruction, VAEs add a probabilistic twist that makes them true generative models.
In this tutorial, we'll dive into:
- The theory behind VAEs
- The key components that make them work
- How to implement VAEs in PyTorch
- Applications and real-world examples
By the end, you'll understand how VAEs differ from traditional autoencoders and be able to create your own generative models for tasks like image synthesis, data augmentation, and more.
Understanding Variational Autoencoders
From Autoencoders to VAEs
Before diving into VAEs, let's quickly recap what an autoencoder does:
- An encoder compresses the input into a lower-dimensional latent space
- A decoder reconstructs the input from this compressed representation
A traditional autoencoder learns a fixed encoding for each input. In contrast, a VAE encodes inputs as probability distributions in the latent space, making them truly generative.
The VAE Architecture
A VAE consists of the following components:
- Encoder: Maps input to parameters of a probability distribution (typically mean and variance of a Gaussian)
- Sampling layer: Draws samples from this distribution using the reparameterization trick
- Decoder: Reconstructs the input from the sampled latent vector
The magic of VAEs comes from two key innovations:
- Representing the latent space as a probability distribution
- A special loss function that balances reconstruction quality with the structure of the latent space
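For reference, the objective a VAE minimizes (the negative evidence lower bound, or ELBO) combines exactly these two pieces: a reconstruction term plus a regularization term that pulls the encoder's distribution toward a standard normal prior,

loss = E[-log p(x | z)] + KL(q(z | x) || N(0, I)),   where for a Gaussian encoder   KL = -0.5 * Σ_j (1 + log σ_j² - μ_j² - σ_j²)

This closed-form KL term is exactly what the loss_function implemented below computes.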
Implementing a VAE in PyTorch
Let's start by implementing a basic VAE for the MNIST dataset:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Set random seed for reproducibility
torch.manual_seed(42)
Defining the VAE Model
class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        mu = self.fc_mu(h1)
        log_var = self.fc_var(h1)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
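As a quick sanity check (a minimal sketch using the default constructor arguments above), you can push a random batch through the model and confirm the output shapes:

# Shape check with a fake batch of 16 MNIST-sized images
vae = VAE()
x = torch.randn(16, 1, 28, 28)
recon, mu, log_var = vae(x)
print(recon.shape)    # torch.Size([16, 784])
print(mu.shape)       # torch.Size([16, 20])
print(log_var.shape)  # torch.Size([16, 20])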
Training the VAE
Now, let's set up the training loop with our specialized loss function:
# VAE loss function
def loss_function(recon_x, x, mu, log_var):
    # Reconstruction loss (binary cross entropy for MNIST)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL divergence
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

# Training function
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = model(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] '
                  f'Loss: {loss.item() / len(data):.6f}')
    print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')
Setting Up the Data
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# MNIST Dataset (ToTensor scales pixel values to [0, 1], matching the sigmoid output and BCE loss)
transform = transforms.Compose([
    transforms.ToTensor()
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
# Data loader
batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
Training Our Model
# Initialize model and optimizer
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Number of epochs to train for
epochs = 10
# Train the model
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
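To track generalization, you can also measure the loss on the held-out test set after each epoch. A minimal evaluation sketch that reuses the loss_function and test_loader defined above:

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon_batch, mu, log_var = model(data)
            test_loss += loss_function(recon_batch, data, mu, log_var).item()
    test_loss /= len(test_loader.dataset)
    print(f'====> Test set loss: {test_loss:.4f}')

# e.g. call test(model, device, test_loader) after each training epoch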
The Reparameterization Trick Explained
One of the key innovations in VAEs is the reparameterization trick. This allows us to backpropagate through the random sampling process.
Instead of directly sampling from a distribution parameterized by the encoder's output, which would break backpropagation, we:
- Output the mean (μ) and log-variance (log σ²) from the encoder
- Sample a standard normal random variable ε ~ N(0, 1)
- Compute the latent variable as z = μ + σ * ε, where σ = exp(0.5 * log σ²)
This clever trick allows gradients to flow through the sampling operation, making end-to-end training possible.
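To see that gradients really do flow through the sampled z back to the encoder's outputs, here is a tiny standalone check (a sketch; the tensor values are arbitrary):

# Gradients flow through z = mu + sigma * eps back to mu and log_var
mu = torch.zeros(3, requires_grad=True)
log_var = torch.zeros(3, requires_grad=True)

std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)   # the randomness is isolated in eps, which needs no gradient
z = mu + eps * std

z.sum().backward()
print(mu.grad)       # tensor([1., 1., 1.])
print(log_var.grad)  # depends on the sampled eps, but is well-defined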
Visualizing the Results
Let's visualize both the reconstructions and new samples generated by our VAE:
def visualize_reconstructions(model, data_loader, device):
    model.eval()
    with torch.no_grad():
        # Get a batch of test data
        data, _ = next(iter(data_loader))
        data = data.to(device)

        # Reconstruct the data
        recon_data, _, _ = model(data)

        # Plot original and reconstructed images
        plt.figure(figsize=(10, 4))
        for i in range(10):
            # Original images
            plt.subplot(2, 10, i + 1)
            plt.imshow(data[i].cpu().numpy().reshape(28, 28), cmap='gray')
            plt.axis('off')

            # Reconstructed images
            plt.subplot(2, 10, i + 11)
            plt.imshow(recon_data[i].cpu().numpy().reshape(28, 28), cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        plt.show()
def generate_samples(model, device, num_samples=10):
    model.eval()
    with torch.no_grad():
        # Sample from the standard normal prior (20 must match the model's latent_dim)
        z = torch.randn(num_samples, 20).to(device)

        # Decode the samples
        samples = model.decode(z)

        # Plot generated samples
        plt.figure(figsize=(10, 1))
        for i in range(num_samples):
            plt.subplot(1, 10, i + 1)
            plt.imshow(samples[i].cpu().numpy().reshape(28, 28), cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        plt.show()
# Visualize reconstructions
visualize_reconstructions(model, test_loader, device)
# Generate new samples
generate_samples(model, device)
Going Beyond: Convolutional VAEs
For image data, convolutional architectures often perform better. Here's a convolutional VAE implementation:
class ConvVAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(ConvVAE, self).__init__()

        # Encoder
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_var = nn.Linear(64 * 7 * 7, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, 64 * 7 * 7)
        self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def encode(self, x):
        h1 = F.relu(self.conv1(x))
        h2 = F.relu(self.conv2(h1))
        h2_flat = h2.view(-1, 64 * 7 * 7)
        mu = self.fc_mu(h2_flat)
        log_var = self.fc_var(h2_flat)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        h3 = h3.view(-1, 64, 7, 7)
        h4 = F.relu(self.deconv1(h3))
        reconstruction = torch.sigmoid(self.deconv2(h4))
        return reconstruction

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
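One caveat: the loss_function defined earlier flattens its inputs to 784 features, while ConvVAE returns images of shape (N, 1, 28, 28). A small adaptation, sketched below, keeps the tensors in image form so the same training loop can be reused:

# Loss for ConvVAE: compare tensors in (N, 1, 28, 28) form instead of flattening
def conv_vae_loss(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

conv_model = ConvVAE().to(device)
conv_optimizer = optim.Adam(conv_model.parameters(), lr=1e-3)
# The training loop is the same as before, with conv_vae_loss used in place of loss_function.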
Real-World Applications of VAEs
VAEs have found applications in numerous domains:
1. Image Generation
VAEs can generate new images similar to those they were trained on. This is useful for:
- Creating synthetic datasets for training other models
- Data augmentation when real data is limited
- Creative applications like art generation
2. Anomaly Detection
Since VAEs learn the distribution of normal data, they can identify outliers:
def detect_anomalies(model, data, threshold=100):
    model.eval()
    with torch.no_grad():
        # Forward pass
        recon, mu, log_var = model(data)

        # Calculate per-sample reconstruction error
        recon_error = F.mse_loss(recon, data.view(-1, 784), reduction='none')
        recon_error = recon_error.sum(dim=1)

        # Flag samples whose error exceeds the threshold
        # (the threshold is data-dependent and should be tuned on a validation set)
        anomalies = recon_error > threshold
        return anomalies, recon_error
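A quick usage sketch, assuming the trained fully-connected model and test_loader from earlier (the threshold value here is arbitrary and should be tuned on validation data):

data, _ = next(iter(test_loader))
data = data.to(device)
anomalies, errors = detect_anomalies(model, data, threshold=100)
print(f'Flagged {anomalies.sum().item()} of {len(data)} images as anomalous')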
3. Data Compression
VAEs can be used for lossy compression of images or other high-dimensional data:
def compress_data(encoder_model, data):
    encoder_model.eval()
    with torch.no_grad():
        # Encode data
        mu, log_var = encoder_model.encode(data.view(-1, 784))
        z = encoder_model.reparameterize(mu, log_var)
        return z  # Compressed representation

def decompress_data(decoder_model, z):
    decoder_model.eval()
    with torch.no_grad():
        # Decode compressed representation
        return decoder_model.decode(z)
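Note that compress_data draws a random sample, so compressing the same input twice gives slightly different codes. For deterministic compression it is common to keep only the mean; a minimal variant:

def compress_data_deterministic(encoder_model, data):
    encoder_model.eval()
    with torch.no_grad():
        mu, _ = encoder_model.encode(data.view(-1, 784))
        return mu  # use the mean as the compressed code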
4. Drug Discovery
VAEs can be used to generate new molecular structures with desired properties:
# Conceptual example (not runnable)
class MolecularVAE(nn.Module):
    # Encodes and decodes molecular representations
    def encode(self, molecular_representation):
        # Convert molecular representation to latent space
        pass

    def decode(self, z):
        # Convert latent vector back to molecular representation
        pass

    def generate_molecules_with_properties(self, target_properties):
        # Sample latent space and filter based on properties
        pass
Advanced Concepts in VAEs
Conditional VAEs
Conditional VAEs allow us to generate samples conditioned on a specific class or attribute:
class ConditionalVAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10):
        super(ConditionalVAE, self).__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

        self.num_classes = num_classes

    def encode(self, x, c):
        # One-hot encode the class (created on the same device as the labels)
        c_one_hot = torch.zeros(c.size(0), self.num_classes, device=c.device)
        c_one_hot.scatter_(1, c.unsqueeze(1), 1)

        # Concatenate input and class
        inputs = torch.cat([x, c_one_hot], 1)
        h1 = F.relu(self.fc1(inputs))
        mu = self.fc_mu(h1)
        log_var = self.fc_var(h1)
        return mu, log_var

    def decode(self, z, c):
        # One-hot encode the class (created on the same device as the labels)
        c_one_hot = torch.zeros(c.size(0), self.num_classes, device=c.device)
        c_one_hot.scatter_(1, c.unsqueeze(1), 1)

        # Concatenate latent vector and class
        inputs = torch.cat([z, c_one_hot], 1)
        h3 = F.relu(self.fc3(inputs))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x, c):
        mu, log_var = self.encode(x.view(-1, 784), c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
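Training follows the same recipe as before, except that the labels from the DataLoader are passed to the model alongside the images; once trained, you can ask the decoder for a specific digit. A rough sketch (the variable name cvae and the training step are placeholders):

# Generate ten examples of the digit 7 with a conditional VAE
cvae = ConditionalVAE().to(device)
# ... train cvae by calling cvae(data, labels) inside a loop like train() above ...

with torch.no_grad():
    z = torch.randn(10, 20).to(device)
    labels = torch.full((10,), 7, dtype=torch.long, device=device)
    samples = cvae.decode(z, labels)  # shape: (10, 784)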
β-VAE: Controlling Disentanglement
The β-VAE is a variant that introduces a hyperparameter β to control the trade-off between reconstruction quality and latent space disentanglement:
# Modified loss function for β-VAE
def beta_vae_loss(recon_x, x, mu, log_var, beta=4.0):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # Beta scales the KL divergence term
    return BCE + beta * KLD
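To use it, swap beta_vae_loss in for loss_function inside the training loop; β = 1 recovers the standard VAE, while larger values typically encourage a more disentangled latent space at the cost of reconstruction quality. For example:

# Inside the training loop, replace the loss computation with:
loss = beta_vae_loss(recon_batch, data, mu, log_var, beta=4.0)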
Summary
In this tutorial, we:
- Learned about the theory behind Variational Autoencoders
- Implemented a basic VAE in PyTorch for the MNIST dataset
- Extended our implementation to convolutional architectures
- Explored real-world applications like image generation and anomaly detection
- Touched on advanced VAE variants like Conditional VAEs and β-VAEs
VAEs represent a powerful framework in generative modeling that bridges traditional autoencoders and generative models. They allow us to both compress data and generate new samples, making them versatile tools in the deep learning toolkit.
Further Resources and Exercises
Resources
- Original VAE Paper by Kingma and Welling
- Tutorial on Variational Autoencoders
- PyTorch VAE Examples
- β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework
Exercises
- Latent Space Exploration: Implement a visualization of the latent space for MNIST digits. Try to find meaningful directions in the latent space.
- Image Morphing: Generate a sequence of images that morphs one digit into another by linearly interpolating in the latent space.
- Conditional Generation: Implement a conditional VAE that can generate specific MNIST digits based on a class label.
- Colored Image VAE: Extend the ConvVAE to work with colored images like CIFAR-10.
- Anomaly Detection: Create an anomaly detection system using VAEs to identify unusual images in a dataset.
Happy coding and exploring the generative world of VAEs!