PyTorch Generative Models
Have you ever wondered how AI can generate realistic images of people who don't exist, create art in different styles, or transform photos from day to night? All these fascinating applications are powered by generative models - one of the most exciting areas in modern deep learning and computer vision.
In this tutorial, we'll dive into the world of generative models using PyTorch, exploring how they work and how you can build your own.
What Are Generative Models?
Generative models are a class of machine learning models that learn to generate new data samples that resemble a given training dataset. Unlike discriminative models that learn boundaries between classes, generative models learn the underlying distribution of the data itself.
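To make "learning the underlying distribution" concrete, here is a minimal sketch of about the simplest generative model possible: a one-dimensional Gaussian fitted to some data and then sampled. The toy dataset and the variable names (`mu_hat`, `sigma_hat`, `new_samples`) are assumptions made for illustration; real image models learn far more complex distributions, but the two phases (fit, then sample) are the same.

```python
import torch

torch.manual_seed(0)

# Toy "training set": 1,000 points from a distribution the model does not know
# (here it is secretly a Gaussian with mean 3 and standard deviation 2).
data = 3.0 + 2.0 * torch.randn(1000)

# "Training": estimate the parameters of the distribution from the data.
mu_hat = data.mean()
sigma_hat = data.std()

# "Generation": draw brand-new samples from the fitted distribution.
fitted = torch.distributions.Normal(mu_hat, sigma_hat)
new_samples = fitted.sample((5,))

print(f"fitted mean={mu_hat.item():.2f}, std={sigma_hat.item():.2f}")
print(new_samples)
```

A discriminative model trained on the same data would only learn to answer questions about given points (for example, which class they belong to); it would have no way to produce `new_samples`.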
In computer vision, generative models can:
- Generate new, realistic-looking images
- Transform images from one domain to another
- Complete or restore damaged images
- Generate variations of existing images
Types of Generative Models
We'll focus on three popular types of generative models:
- Variational Autoencoders (VAEs): Models that learn a compressed, probabilistic latent representation of the data and decode samples from it
- Generative Adversarial Networks (GANs): Two-network systems that compete to generate realistic data
- Diffusion Models: Newer models that learn to reverse a gradual noising process, turning random noise into data
1. Variational Autoencoders (VAEs)
VAEs learn to compress data into a lower-dimensional latent space and then reconstruct it. What makes them "generative" is their ability to sample from this latent space to create new data.
How VAEs Work
A VAE consists of two main parts:
- An encoder that compresses input images into a latent space
- A decoder that reconstructs images from the latent space
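What sets a VAE apart from a plain autoencoder is that the encoder outputs the mean and log-variance of a Gaussian for each input, and a KL-divergence term in the loss pulls these distributions toward a standard normal. Because the latent space is regularized this way, we can sample new points from a standard normal distribution and decode them into new images.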
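The two ideas that make this work, the KL penalty and the reparameterization trick, can be checked in isolation. The sketch below uses made-up encoder outputs (`mu` and `log_var` are random tensors, not the output of a trained model) to compare the closed-form KL expression used in VAE losses against `torch.distributions.kl_divergence`, and to show that gradients flow through a reparameterized sample.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical encoder outputs for a batch of 4 inputs and a 3-D latent space.
mu = torch.randn(4, 3)
log_var = torch.randn(4, 3)

# Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and dimensions.
kld_closed_form = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# The same quantity computed with torch.distributions as a cross-check.
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_reference = kl_divergence(q, p).sum()

print(kld_closed_form.item(), kld_reference.item())

# Reparameterization trick: z = mu + sigma * eps keeps z differentiable with
# respect to mu, because all of the randomness lives in eps.
mu.requires_grad_(True)
eps = torch.randn_like(mu)
z = mu + torch.exp(0.5 * log_var) * eps
z.sum().backward()
print(mu.grad)  # gradient of z.sum() with respect to mu
```

Sampling `z` directly from `Normal(mu, sigma)` with `.sample()` would block the gradient, which is why the noise is drawn separately and then shifted and scaled.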
Implementing a Simple VAE in PyTorch
Let's create a basic VAE for the MNIST dataset:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Set random seed for reproducibility
torch.manual_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
batch_size = 128
learning_rate = 1e-3
num_epochs = 20
latent_dim = 20

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# VAE Model
class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        h = F.relu(self.fc2(z))
        x_reconstructed = torch.sigmoid(self.fc3(h))
        return x_reconstructed

    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        x_reconstructed = self.decode(z)
        return x_reconstructed, mu, log_var

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Loss function
def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

# Training loop
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = model(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss/len(train_loader.dataset):.4f}')

# Generate new images
with torch.no_grad():
    z = torch.randn(64, latent_dim).to(device)
    sample = model.decode(z).cpu()

# Display generated images
plt.figure(figsize=(8, 8))
for i in range(64):
    plt.subplot(8, 8, i + 1)
    plt.imshow(sample[i].reshape(28, 28), cmap='gray')
    plt.axis('off')
plt.tight_layout()
plt.show()
Expected Output
Running this code will train a VAE on the MNIST dataset and then generate new handwritten digit images. The output images will look similar to MNIST digits but will be generated entirely by your model.
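One way to see that the latent space is smooth is to decode points along a straight line between two latent vectors. The helper below is a sketch rather than part of the tutorial's code: `interpolate_latents` is a hypothetical name, and `toy_decoder` is an untrained stand-in so the block runs on its own. With the trained VAE you would pass `model.decode` instead (and move the latent vectors to the same device as the model).

```python
import torch
import torch.nn as nn

def interpolate_latents(decode, z_start, z_end, steps=8):
    """Decode points on the straight line between two latent vectors."""
    # Interpolation weights from 0 (z_start) to 1 (z_end), shape (steps, 1).
    weights = torch.linspace(0.0, 1.0, steps).unsqueeze(1)
    z_path = (1 - weights) * z_start + weights * z_end
    with torch.no_grad():
        return decode(z_path)

# Stand-in decoder (assumption): same input and output sizes as the VAE decoder.
latent_dim = 20
toy_decoder = nn.Sequential(nn.Linear(latent_dim, 784), nn.Sigmoid())

z_a, z_b = torch.randn(latent_dim), torch.randn(latent_dim)
frames = interpolate_latents(toy_decoder, z_a, z_b, steps=8)
print(frames.shape)  # one flattened 28x28 image per interpolation step
```

With a trained decoder, reshaping each row of `frames` to 28x28 and plotting them side by side typically shows one digit morphing gradually into another.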
2. Generative Adversarial Networks (GANs)
GANs consist of two neural networks competing against each other:
- The Generator creates fake images
- The Discriminator tries to distinguish real images from fake ones
Through this competition, the Generator gets better at creating realistic images.
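The competition is expressed entirely through the loss functions. The following sketch uses hand-picked discriminator outputs (the numbers in `pred_on_real` and `pred_on_fake` are made up for illustration) to show how the same fake predictions are scored against opposite labels for the two networks.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs: probability that each image is real.
pred_on_real = torch.tensor([[0.9], [0.8]])  # D is fairly sure these are real
pred_on_fake = torch.tensor([[0.3], [0.1]])  # D is fairly sure these are fake

real_label = torch.ones_like(pred_on_real)
fake_label = torch.zeros_like(pred_on_fake)

# Discriminator: rewarded for calling real images real and fake images fake.
d_loss = criterion(pred_on_real, real_label) + criterion(pred_on_fake, fake_label)

# Generator: rewarded when the discriminator calls its fakes real,
# so the fake predictions are compared against the "real" label.
g_loss = criterion(pred_on_fake, real_label)

print(f"d_loss={d_loss.item():.4f}, g_loss={g_loss.item():.4f}")
```

Here the discriminator is winning, so its loss is small and the generator's loss is large. As the generator improves, `pred_on_fake` drifts upward and the balance shifts.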
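Implementing a Simple GAN for MNIST
Let's implement a simple GAN for generating MNIST digits. To keep the code short, both networks use fully connected layers; a Deep Convolutional GAN (DCGAN) would replace them with convolutional layers.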
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
batch_size = 128
nz = 100  # Size of the latent z vector
num_epochs = 25
lr = 0.0002
beta1 = 0.5

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True)

# Generator Network
class Generator(nn.Module):
    def __init__(self, nz):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Input is latent vector z
            nn.Linear(nz, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),  # 28x28 = 784
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input).view(-1, 1, 28, 28)

# Discriminator Network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input.view(-1, 784))

# Create the models
netG = Generator(nz).to(device)
netD = Discriminator().to(device)

# Loss function and optimizers
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Fixed noise for visualization
fixed_noise = torch.randn(64, nz, device=device)

# Training Loop
for epoch in range(num_epochs):
    for i, (data, _) in enumerate(train_loader):
        # Get batch size
        batch_size = data.size(0)

        # Configure input
        real_imgs = data.to(device)

        # Labels
        real_label = torch.ones(batch_size, 1).to(device)
        fake_label = torch.zeros(batch_size, 1).to(device)

        # -----------------
        # Train Discriminator
        # -----------------
        optimizerD.zero_grad()

        # Loss on real images
        real_pred = netD(real_imgs)
        d_real_loss = criterion(real_pred, real_label)

        # Loss on fake images
        z = torch.randn(batch_size, nz).to(device)
        fake_imgs = netG(z)
        fake_pred = netD(fake_imgs.detach())
        d_fake_loss = criterion(fake_pred, fake_label)

        # Total discriminator loss
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizerD.step()

        # -----------------
        # Train Generator
        # -----------------
        optimizerG.zero_grad()

        # Generate images and get discriminator prediction
        gen_pred = netD(fake_imgs)

        # Generator loss
        g_loss = criterion(gen_pred, real_label)
        g_loss.backward()
        optimizerG.step()

    # Print progress
    if epoch % 5 == 0:
        print(f"Epoch [{epoch}/{num_epochs}] D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")

        # Generate and save images
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()

        plt.figure(figsize=(8, 8))
        for j in range(64):
            plt.subplot(8, 8, j+1)
            plt.imshow(fake[j][0], cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(f"gan_images_epoch_{epoch}.png")
        plt.close()

print("Training complete!")
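Every five epochs the script prints the current losses and saves an 8x8 grid of samples generated from the same fixed noise vectors (gan_images_epoch_0.png, gan_images_epoch_5.png, and so on), so you can watch the digits become more realistic as training progresses.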
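The networks above are built from fully connected layers. A DCGAN-style generator would instead upsample the noise vector with transposed convolutions. The class below is a sketch of what that could look like for 28x28 images; `ConvGenerator` and its layer sizes are illustrative choices, not a reference architecture, and the discriminator would need a matching convolutional rewrite.

```python
import torch
import torch.nn as nn

class ConvGenerator(nn.Module):
    """Hypothetical DCGAN-style generator producing 1x28x28 images."""

    def __init__(self, nz=100):
        super().__init__()
        self.main = nn.Sequential(
            # (nz, 1, 1) -> (128, 7, 7)
            nn.ConvTranspose2d(nz, 128, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # (128, 7, 7) -> (64, 14, 14)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # (64, 14, 14) -> (1, 28, 28)
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # outputs in [-1, 1], matching Normalize((0.5,), (0.5,))
        )

    def forward(self, z):
        # Reshape flat noise vectors to (batch, nz, 1, 1) feature maps.
        return self.main(z.view(z.size(0), -1, 1, 1))

netG_conv = ConvGenerator(nz=100)
fake = netG_conv(torch.randn(16, 100))
print(fake.shape)  # torch.Size([16, 1, 28, 28])
```

Because the output shape and value range match the fully connected generator, this class could be swapped into the same training loop without other changes.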
3. Intro to Diffusion Models
Diffusion models have gained tremendous popularity recently, powering tools like DALL-E 2 and Stable Diffusion. They work by:
- Adding noise to images gradually until little of the original remains (forward process)
- Learning to remove that noise step by step (reverse process)
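Written out, the two processes look as follows. The notation is the one commonly used for denoising diffusion models and is an assumption here, since the tutorial does not fix its own symbols: \beta_t is the noise added at step t (the `betas` tensor in the demo code in this section) and \bar{\alpha}_t is the running product stored in `alphas_cumprod`.

```latex
% Forward process: one noising step with variance schedule \beta_t
q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right)

% With \alpha_t = 1 - \beta_t and \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s,
% any step can be sampled directly from the clean image x_0:
x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)

% Reverse process: a network \epsilon_\theta is trained to predict the noise
L_{\text{simple}} = \mathbb{E}_{x_0,\,t,\,\epsilon}
  \left[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\right]
```

The second equation is what makes training practical: a noisy image at any timestep can be produced in one shot, without simulating every intermediate step.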
Let's implement a very simplified diffusion model to understand the concept:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Simplified UNet architecture for noise prediction
class SimpleUNet(nn.Module):
    def __init__(self):
        super(SimpleUNet, self).__init__()
        # Encoder layers
        self.enc1 = nn.Conv2d(1, 64, 3, padding=1)
        self.enc2 = nn.Conv2d(64, 128, 3, padding=1, stride=2)
        self.enc3 = nn.Conv2d(128, 256, 3, padding=1, stride=2)

        # Time embeddings
        self.time_mlp = nn.Sequential(
            nn.Linear(1, 256),
            nn.SiLU(),
            nn.Linear(256, 256),
        )

        # Decoder layers with skip connections
        # dec3 takes 256 channels: the time embedding is added to e3, not concatenated
        self.dec3 = nn.Conv2d(256, 128, 3, padding=1)
        self.dec2 = nn.Conv2d(128 + 128, 64, 3, padding=1)
        self.dec1 = nn.Conv2d(64 + 64, 64, 3, padding=1)
        self.final = nn.Conv2d(64, 1, 3, padding=1)

    def forward(self, x, t):
        # Encoder path
        e1 = F.silu(self.enc1(x))
        e2 = F.silu(self.enc2(e1))
        e3 = F.silu(self.enc3(e2))

        # Time embedding
        temb = self.time_mlp(t.unsqueeze(1).float())
        temb = temb.unsqueeze(2).unsqueeze(3).expand(-1, -1, e3.shape[2], e3.shape[3])

        # Combine with time embedding
        e3 = e3 + temb

        # Decoder path with skip connections
        d3 = F.silu(self.dec3(e3))
        d3 = F.interpolate(d3, scale_factor=2)  # Upsample
        d2 = F.silu(self.dec2(torch.cat([d3, e2], dim=1)))
        d2 = F.interpolate(d2, scale_factor=2)  # Upsample
        d1 = F.silu(self.dec1(torch.cat([d2, e1], dim=1)))
        return self.final(d1)

# Example usage (not a full training loop)
def demo_diffusion():
    # Load a single MNIST image for demonstration
    transform = transforms.Compose([transforms.ToTensor()])
    mnist = datasets.MNIST('./data', download=True, train=True, transform=transform)
    image, _ = mnist[0]

    # Model setup (untrained; it is not used by the visualization)
    model = SimpleUNet()

    # Forward diffusion (add noise)
    timesteps = 50
    betas = torch.linspace(0.0001, 0.02, timesteps)
    alphas = 1 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # Add noise to the image for different timesteps
    noisy_images = []
    for t in range(0, timesteps, 10):  # Saving every 10th step
        alpha_cumprod = alphas_cumprod[t]
        noise = torch.randn_like(image)
        noisy_image = torch.sqrt(alpha_cumprod) * image + torch.sqrt(1 - alpha_cumprod) * noise
        noisy_images.append(noisy_image)

    # Visualize noisy images
    plt.figure(figsize=(12, 4))
    for i, img in enumerate(noisy_images):
        plt.subplot(1, len(noisy_images), i+1)
        plt.imshow(img[0], cmap='gray')
        plt.title(f"t={i*10}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

    print("This is just a visualization of the forward diffusion process.")
    print("A full diffusion model would train to predict the noise at each step.")

# Run the demo
demo_diffusion()
This is a simplified demonstration of the diffusion concept: it only visualizes the forward (noising) process, and the SimpleUNet is defined but never trained. A complete implementation would require significantly more code to handle training and sampling.
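To give a rough idea of what that extra code involves, here is a sketch of one training step and a sampling loop in the style of denoising diffusion models. Everything in it is an assumption made for illustration: `TinyNoisePredictor` is a small stand-in for a U-Net, the batch is random data rather than MNIST, and a single update will not produce meaningful images. With only 50 steps on this noise schedule the final noisy image also stays well short of pure noise, so practical setups use many more steps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

timesteps = 50
betas = torch.linspace(0.0001, 0.02, timesteps)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

class TinyNoisePredictor(nn.Module):
    """Stand-in for a U-Net: maps (noisy image, timestep) to predicted noise."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 1, 3, padding=1)

    def forward(self, x, t):
        # Feed the normalized timestep to the network as an extra input channel.
        t_map = (t.float() / timesteps).view(-1, 1, 1, 1).expand(-1, 1, *x.shape[2:])
        return self.conv2(F.silu(self.conv1(torch.cat([x, t_map], dim=1))))

def training_step(model, optimizer, x0):
    """One noise-prediction update on a batch of clean images x0."""
    t = torch.randint(0, timesteps, (x0.size(0),))   # random timestep per image
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * noise
    loss = F.mse_loss(model(x_t, t), noise)          # predict the added noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def sample(model, shape):
    """Start from random noise and denoise one step at a time."""
    x = torch.randn(shape)
    for step in reversed(range(timesteps)):
        t = torch.full((shape[0],), step, dtype=torch.long)
        predicted_noise = model(x, t)
        coef = betas[step] / torch.sqrt(1 - alphas_cumprod[step])
        x = (x - coef * predicted_noise) / torch.sqrt(alphas[step])
        if step > 0:
            # Add fresh noise at every step except the last one.
            x = x + torch.sqrt(betas[step]) * torch.randn_like(x)
    return x

model = TinyNoisePredictor()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = torch.rand(8, 1, 28, 28)  # random stand-in for a batch of MNIST images
loss_value = training_step(model, optimizer, batch)
generated = sample(model, (4, 1, 28, 28))
print(loss_value, generated.shape)
```

In a real run, `training_step` would be called for many epochs over the MNIST loader, and `TinyNoisePredictor` would be replaced by a network with skip connections such as the SimpleUNet defined earlier.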
Real-World Applications
Generative models are powering many exciting applications:
1. Image Synthesis
- Creating realistic portraits for video games and movie characters
- Generating stock-style photos on demand (the copyright status of generated images is still debated)
- Creating art in various styles with tools like DALL-E and Midjourney
2. Image-to-Image Translation
- Converting satellite images to maps
- Transforming sketches into photorealistic images
- Colorizing black and white photos
3. Image Enhancement
- Super-resolution (upscaling low-resolution images)
- Image denoising and restoration
- Filling in missing parts of images (inpainting)
Challenges in Generative Modeling
Working with generative models comes with challenges:
- Training instability: GANs can be especially difficult to train
- Mode collapse: A model (most often a GAN generator) might produce only a few types of outputs and ignore the rest of the data distribution
- Evaluation: It's hard to objectively measure the quality of generated images
- Computational requirements: Training advanced models often needs significant GPU resources
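Mode collapse in particular can be spotted with a very crude diversity check: if generated samples are nearly identical, their average pairwise distance is close to zero. The function below is a hypothetical diagnostic written for this illustration, not a standard metric, and the "diverse" and "collapsed" batches are synthetic stand-ins for generator output.

```python
import torch

def mean_pairwise_distance(samples):
    """Average L2 distance between all distinct pairs of flattened samples."""
    flat = samples.flatten(start_dim=1)
    distances = torch.cdist(flat, flat)      # (n, n) matrix of L2 distances
    n = flat.size(0)
    # The diagonal (distance of a sample to itself) contributes roughly zero.
    return distances.sum() / (n * (n - 1))

torch.manual_seed(0)
diverse = torch.rand(32, 1, 28, 28)                       # every sample differs
collapsed = torch.rand(1, 1, 28, 28).repeat(32, 1, 1, 1)  # 32 copies of one image

print(mean_pairwise_distance(diverse).item())
print(mean_pairwise_distance(collapsed).item())
```

A score near zero on real generator output is a warning sign, but a high score says nothing about image quality, which is part of why evaluation is listed as its own challenge.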
Summary
We've explored three major types of generative models in PyTorch:
- Variational Autoencoders (VAEs): Great for learning compressed representations of data and generating new samples.
- Generative Adversarial Networks (GANs): Powerful for generating highly realistic images through competition between generator and discriminator.
- Diffusion Models: The newest of the three approaches, generating high-quality images by gradually removing noise from a random starting point.
Each model has its strengths and ideal use cases, and they form the backbone of many exciting computer vision applications.
Additional Resources and Exercises
Resources for Further Learning:
- PyTorch's official GAN tutorial
- Hugging Face's Diffusion Models course
- Stanford's CS231n lecture on Generative Models
Exercises:
- Modify the VAE to work with a color image dataset like CIFAR-10.
- Add conditioning to the GAN implementation to generate specific MNIST digits.
- Implement CycleGAN for image-to-image translation between two domains.
- Build an image inpainting system using one of the generative models we discussed.
- Compare the image quality of outputs from VAE, GAN, and a pre-trained diffusion model.
With these fundamentals, you're well on your way to creating your own generative AI applications!