PyTorch VAEs

Introduction

Variational Autoencoders (VAEs) are a powerful class of generative models that allow us to not only compress data but also generate new samples that resemble our training data. Unlike regular autoencoders which focus solely on reconstruction, VAEs add a probabilistic twist that makes them true generative models.

In this tutorial, we'll dive into:

  • The theory behind VAEs
  • The key components that make them work
  • How to implement VAEs in PyTorch
  • Applications and real-world examples

By the end, you'll understand how VAEs differ from traditional autoencoders and be able to create your own generative models for tasks like image synthesis, data augmentation, and more.

Understanding Variational Autoencoders

From Autoencoders to VAEs

Before diving into VAEs, let's quickly recap what an autoencoder does:

  1. An encoder compresses the input into a lower-dimensional latent space
  2. A decoder reconstructs the input from this compressed representation

A traditional autoencoder learns a fixed encoding for each input. In contrast, a VAE encodes inputs as probability distributions in the latent space, making them truly generative.

The VAE Architecture

A VAE consists of the following components:

  1. Encoder: Maps the input to the parameters of a probability distribution (typically the mean and log-variance of a Gaussian)
  2. Sampling layer: Draws samples from this distribution using the reparameterization trick
  3. Decoder: Reconstructs the input from the sampled latent vector

The magic of VAEs comes from two key innovations:

  • Representing the latent space as a probability distribution
  • A loss function that balances reconstruction quality against the structure of the latent space (sketched below)
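
Concretely, a VAE is trained to minimize the negative evidence lower bound (ELBO), which is the sum of a reconstruction term and a Kullback-Leibler (KL) divergence between the encoder's distribution and a standard normal prior. For a Gaussian encoder with mean μ and variance σ², the KL term has the closed form

KL = -½ · Σⱼ (1 + log σⱼ² - μⱼ² - σⱼ²)

which is exactly what the loss function implemented later in this tutorial computes.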

Implementing a VAE in PyTorch

Let's start by implementing a basic VAE for the MNIST dataset:

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Set random seed for reproducibility
torch.manual_seed(42)

Defining the VAE Model

python
class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        mu = self.fc_mu(h1)
        log_var = self.fc_var(h1)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
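
Before training, it can help to sanity-check the tensor shapes with a dummy batch. This is purely illustrative; the tensor below is random noise, not MNIST data:

python
# Push a fake MNIST-sized batch through an untrained VAE to check shapes
vae = VAE()
dummy = torch.rand(16, 1, 28, 28)
recon, mu, log_var = vae(dummy)
print(recon.shape, mu.shape, log_var.shape)  # (16, 784), (16, 20), (16, 20)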

Training the VAE

Now, let's set up the training loop with our specialized loss function:

python
# VAE loss function
def loss_function(recon_x, x, mu, log_var):
    # Reconstruction loss (binary cross entropy for MNIST)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # KL divergence
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return BCE + KLD

# Training function
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = model(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] '
                  f'Loss: {loss.item() / len(data):.6f}')

    print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

Setting Up the Data

python
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Data loader
batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Training Our Model

python
# Initialize model and optimizer
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Number of epochs to train for
epochs = 10

# Train the model
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
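
It is also worth tracking the loss on held-out data. Here is a minimal evaluation loop, assuming the model, loss_function, and test_loader defined above (the test function itself is introduced here):

python
# Evaluate the average loss on the test set (no gradient tracking needed)
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, log_var = model(data)
            test_loss += loss_function(recon, data, mu, log_var).item()
    print(f'====> Test set loss: {test_loss / len(test_loader.dataset):.4f}')

test(model, device, test_loader)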

The Reparameterization Trick Explained

One of the key innovations in VAEs is the reparameterization trick. This allows us to backpropagate through the random sampling process.

Instead of directly sampling from a distribution parameterized by the encoder's output, which would break backpropagation, we:

  1. Output mean (μ) and log-variance (log σ²) from the encoder
  2. Sample a standard normal random variable (ε)
  3. Compute our latent variable as z = μ + σ * ε, where σ = exp(½ · log σ²)

This clever trick allows gradients to flow through the sampling operation, making end-to-end training possible.
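
The point is that all the randomness lives in ε, while μ and σ enter through ordinary differentiable operations. A tiny standalone check (not part of the model above) makes this concrete:

python
# Gradients flow from the sampled z back to mu and log_var
mu = torch.zeros(1, 3, requires_grad=True)
log_var = torch.zeros(1, 3, requires_grad=True)
z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)
z.sum().backward()
print(mu.grad, log_var.grad)  # both gradients exist, so backprop works through the sample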

Visualizing the Results

Let's visualize both the reconstructions and new samples generated by our VAE:

python
def visualize_reconstructions(model, data_loader, device):
    model.eval()
    with torch.no_grad():
        # Get a batch of test data
        data, _ = next(iter(data_loader))
        data = data.to(device)

        # Reconstruct the data
        recon_data, _, _ = model(data)

        # Plot original and reconstructed images
        plt.figure(figsize=(10, 4))
        for i in range(10):
            # Original images
            plt.subplot(2, 10, i + 1)
            plt.imshow(data[i].cpu().numpy().reshape(28, 28), cmap='gray')
            plt.axis('off')

            # Reconstructed images
            plt.subplot(2, 10, i + 11)
            plt.imshow(recon_data[i].cpu().numpy().reshape(28, 28), cmap='gray')
            plt.axis('off')

        plt.tight_layout()
        plt.show()

def generate_samples(model, device, num_samples=10):
    model.eval()
    with torch.no_grad():
        # Sample from standard normal distribution
        z = torch.randn(num_samples, 20).to(device)

        # Decode the samples
        samples = model.decode(z)

        # Plot generated samples
        plt.figure(figsize=(10, 1))
        for i in range(num_samples):
            plt.subplot(1, 10, i + 1)
            plt.imshow(samples[i].cpu().numpy().reshape(28, 28), cmap='gray')
            plt.axis('off')

        plt.tight_layout()
        plt.show()

# Visualize reconstructions
visualize_reconstructions(model, test_loader, device)

# Generate new samples
generate_samples(model, device)

Going Beyond: Convolutional VAEs

For image data, convolutional architectures often perform better. Here's a convolutional VAE implementation:

python
class ConvVAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(ConvVAE, self).__init__()

        # Encoder
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_var = nn.Linear(64 * 7 * 7, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim, 64 * 7 * 7)
        self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def encode(self, x):
        h1 = F.relu(self.conv1(x))
        h2 = F.relu(self.conv2(h1))
        h2_flat = h2.view(-1, 64 * 7 * 7)

        mu = self.fc_mu(h2_flat)
        log_var = self.fc_var(h2_flat)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        h3 = h3.view(-1, 64, 7, 7)

        h4 = F.relu(self.deconv1(h3))
        reconstruction = torch.sigmoid(self.deconv2(h4))
        return reconstruction

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
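
One caveat before training it: the earlier loss flattens images to 784-dimensional vectors, whereas ConvVAE consumes and produces tensors of shape (N, 1, 28, 28). A minimal adjustment under that assumption (the names conv_vae_loss, conv_model, and conv_optimizer are introduced here):

python
# Loss for the convolutional VAE: compare images directly, no flattening
def conv_vae_loss(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

conv_model = ConvVAE().to(device)
conv_optimizer = optim.Adam(conv_model.parameters(), lr=1e-3)

The training loop stays the same apart from calling this loss instead of loss_function.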

Real-World Applications of VAEs

VAEs have found applications in numerous domains:

1. Image Generation

VAEs can generate new images similar to those they were trained on. This is useful for:

  • Creating synthetic datasets for training other models
  • Data augmentation when real data is limited
  • Creative applications like art generation

2. Anomaly Detection

Since VAEs learn the distribution of normal data, they can identify outliers:

python
def detect_anomalies(model, data, threshold=100):
    model.eval()
    with torch.no_grad():
        # Forward pass
        recon, mu, log_var = model(data)

        # Calculate reconstruction error
        recon_error = F.mse_loss(recon, data.view(-1, 784), reduction='none')
        recon_error = recon_error.sum(dim=1)

        # Detect anomalies
        anomalies = recon_error > threshold

        return anomalies, recon_error
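
The threshold is a hyperparameter you would tune on validation data; 100 is only a placeholder. A quick usage sketch with the fully-connected VAE and test loader from earlier:

python
# Flag hard-to-reconstruct images in one test batch
data, _ = next(iter(test_loader))
data = data.to(device)
anomalies, errors = detect_anomalies(model, data, threshold=100)
print(f'{anomalies.sum().item()} of {len(data)} images flagged as anomalous')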

3. Data Compression

VAEs can be used for lossy compression of images or other high-dimensional data:

python
def compress_data(encoder_model, data):
    encoder_model.eval()
    with torch.no_grad():
        # Encode data
        mu, log_var = encoder_model.encode(data.view(-1, 784))
        # Note: sampling adds noise; using mu alone gives a deterministic code
        z = encoder_model.reparameterize(mu, log_var)
        return z  # Compressed representation

def decompress_data(decoder_model, z):
    decoder_model.eval()
    with torch.no_grad():
        # Decode compressed representation
        return decoder_model.decode(z)
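
A round-trip sketch using the model trained earlier (the compression is lossy, so the output only approximates the input):

python
# Compress one test batch to 20-dimensional codes, then decode it again
data, _ = next(iter(test_loader))
z = compress_data(model, data.to(device))  # shape: (batch_size, 20)
restored = decompress_data(model, z)       # shape: (batch_size, 784)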

4. Drug Discovery

VAEs can be used to generate new molecular structures with desired properties:

python
# Conceptual example (not runnable)
class MolecularVAE(nn.Module):
    # Encodes and decodes molecular representations
    def encode(self, molecular_representation):
        # Convert molecular representation to latent space
        pass

    def decode(self, z):
        # Convert latent vector back to molecular representation
        pass

    def generate_molecules_with_properties(self, target_properties):
        # Sample latent space and filter based on properties
        pass

Advanced Concepts in VAEs

Conditional VAEs

Conditional VAEs allow us to generate samples conditioned on a specific class or attribute:

python
class ConditionalVAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, num_classes=10):
        super(ConditionalVAE, self).__init__()

        # Encoder
        self.fc1 = nn.Linear(input_dim + num_classes, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc3 = nn.Linear(latent_dim + num_classes, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

        self.num_classes = num_classes

    def encode(self, x, c):
        # One-hot encode the class (on the same device as the labels)
        c_one_hot = torch.zeros(c.size(0), self.num_classes, device=c.device)
        c_one_hot.scatter_(1, c.unsqueeze(1), 1)

        # Concatenate input and class
        inputs = torch.cat([x, c_one_hot], 1)

        h1 = F.relu(self.fc1(inputs))
        mu = self.fc_mu(h1)
        log_var = self.fc_var(h1)
        return mu, log_var

    def decode(self, z, c):
        # One-hot encode the class (on the same device as the labels)
        c_one_hot = torch.zeros(c.size(0), self.num_classes, device=c.device)
        c_one_hot.scatter_(1, c.unsqueeze(1), 1)

        # Concatenate latent vector and class
        inputs = torch.cat([z, c_one_hot], 1)

        h3 = F.relu(self.fc3(inputs))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x, c):
        mu, log_var = self.encode(x.view(-1, 784), c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
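
Once trained (with class labels passed in alongside the images), the conditional decoder lets you ask for a specific digit. A minimal sketch; the instance cvae is introduced here and would need to be trained first:

python
# Sample a chosen digit from a conditional VAE
cvae = ConditionalVAE().to(device)  # train before use; shown untrained for brevity
cvae.eval()
with torch.no_grad():
    z = torch.randn(1, 20, device=device)
    label = torch.tensor([7], device=device)  # request a "7"
    sample = cvae.decode(z, label)            # shape: (1, 784)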

β-VAE: Controlling Disentanglement

The β-VAE is a variant that introduces a hyperparameter β to control the trade-off between reconstruction quality and latent space disentanglement:

python
# Modified loss function for β-VAE
def beta_vae_loss(recon_x, x, mu, log_var, beta=4.0):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    # Beta scales the KL divergence term
    return BCE + beta * KLD

Summary

In this tutorial, we:

  • Learned about the theory behind Variational Autoencoders
  • Implemented a basic VAE in PyTorch for the MNIST dataset
  • Extended our implementation to convolutional architectures
  • Explored real-world applications like image generation and anomaly detection
  • Touched on advanced VAE variants like Conditional VAEs and β-VAEs

VAEs bridge the gap between traditional autoencoders and fully generative models: they let us both compress data and generate new samples, making them versatile tools in the deep learning toolkit.

Further Resources and Exercises

Resources

  1. Original VAE Paper by Kingma and Welling
  2. Tutorial on Variational Autoencoders
  3. PyTorch VAE Examples
  4. β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework

Exercises

  1. Latent Space Exploration: Implement a visualization of the latent space for MNIST digits. Try to find meaningful directions in the latent space.

  2. Image Morphing: Generate a sequence of images that morphs one digit into another by linearly interpolating in the latent space.

  3. Conditional Generation: Implement a conditional VAE that can generate specific MNIST digits based on a class label.

  4. Colored Image VAE: Extend the ConvVAE to work with colored images like CIFAR-10.

  5. Anomaly Detection: Create an anomaly detection system using VAEs to identify unusual images in a dataset.

Happy coding and exploring the generative world of VAEs!


