PyTorch GANs
Introduction to Generative Adversarial Networks
Generative Adversarial Networks (GANs) are one of the most influential ideas in deep learning of the past decade. Introduced by Ian Goodfellow and his collaborators in 2014, GANs consist of two neural networks—a generator and a discriminator—that are trained simultaneously in a competitive setting.
- The generator learns to create fake data that resembles real data
- The discriminator learns to distinguish between real and fake data
This adversarial process continues until the generator produces data that the discriminator cannot differentiate from real data.
In this tutorial, we'll learn how to implement a basic GAN using PyTorch to generate handwritten digits similar to those in the MNIST dataset.
Understanding the GAN Architecture
Before diving into code, let's understand the core components and workflow of a GAN:
- Generator: Takes random noise as input and outputs synthetic data
- Discriminator: Takes either real data or generated data and outputs a probability indicating whether the input is real or fake
- Training Process: Alternates between training the discriminator and the generator
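Formally, the original paper frames this as a two-player minimax game over the value function

\[ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] \]

The discriminator D tries to maximize V (score real images high and fakes low), while the generator G tries to minimize it. In practice, and in the training loop later in this tutorial, the generator is trained with the non-saturating variant: its fake images are labeled as "real" in the binary cross-entropy loss, which amounts to maximizing log D(G(z)) and gives stronger gradients early in training.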
Prerequisites
To follow along with this tutorial, you'll need PyTorch, torchvision, NumPy, and Matplotlib, which you can install with pip:
pip install torch torchvision numpy matplotlib
Setting Up Our Environment
Let's start by importing the necessary libraries:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
Loading the MNIST Dataset
We'll use the MNIST dataset of handwritten digits:
# Define image transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])
# Load MNIST dataset
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
# Create data loader
batch_size = 64
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)
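Before building any networks, it can help to confirm that the loader yields tensors of the expected shape and value range. This quick check is an optional addition, not part of the original pipeline:

# Optional sanity check: pull one batch and inspect it
imgs, labels = next(iter(train_loader))
print(imgs.shape)                             # torch.Size([64, 1, 28, 28])
print(imgs.min().item(), imgs.max().item())   # roughly -1.0 and 1.0 after normalization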
Building the Discriminator
The discriminator is a binary classifier that tries to distinguish between real and generated images:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # A simple network with convolutional layers
        self.model = nn.Sequential(
            # Input is 1x28x28
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x14x14
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x7x7
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128x3x3
            nn.Flatten(),
            nn.Linear(128 * 3 * 3, 1),
            nn.Sigmoid()  # Output is a probability between 0 and 1
        )

    def forward(self, x):
        return self.model(x)
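As a quick, optional sanity check (not part of the original tutorial), you can pass a dummy batch through the discriminator and confirm that it returns one probability per image:

# Optional: verify the discriminator's output shape on a dummy batch
disc = Discriminator()
dummy = torch.randn(8, 1, 28, 28)   # a batch of 8 random "images"
print(disc(dummy).shape)            # torch.Size([8, 1]), values in (0, 1)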
Building the Generator
The generator creates images from random noise:
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        # Initial size based on upsampling from random noise
        self.init_size = 7  # Will be upsampled to 28x28
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.init_size * self.init_size)
        )
        # Upsample to 28x28 through convolutional layers
        self.model = nn.Sequential(
            nn.BatchNorm2d(128),
            # 128x7x7
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128x14x14
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x28x28
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # Output normalized between -1 and 1
        )

    def forward(self, z):
        # Project and reshape the noise
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        # Generate the image
        img = self.model(out)
        return img
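The same kind of optional shape check works for the generator (here using a latent dimension of 100, the value we set in the next section):

# Optional: verify that the generator maps noise to 1x28x28 images
gen = Generator(latent_dim=100)
z = torch.randn(8, 100)
print(gen(z).shape)   # torch.Size([8, 1, 28, 28]), values in (-1, 1) because of Tanh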
Initializing Models and Optimizers
Now, let's initialize our models and optimizers:
# Hyperparameters
latent_dim = 100 # Dimension of the random noise
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
# Initialize models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
# Binary cross entropy loss
adversarial_loss = nn.BCELoss()
# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))
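Optionally, you can also initialize the convolutional and batch-norm weights the way the DCGAN paper suggests (zero-mean normal with a standard deviation of 0.02). This extra step is not required for the tutorial to work, but it often speeds up convergence:

# Optional DCGAN-style weight initialization (an extra step, not required)
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

generator.apply(weights_init)
discriminator.apply(weights_init)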
Utility Functions for Visualization
Let's create a function to visualize our generated images:
def show_images(images, num_images=25):
    """Display a batch of images in a grid."""
    # Convert tensor images to numpy arrays
    images = images.detach().cpu().numpy()
    # Rescale from [-1, 1] to [0, 1]
    images = (images + 1) / 2.0
    # Create a grid
    n = int(np.sqrt(num_images))
    plt.figure(figsize=(8, 8))
    for i in range(num_images):
        plt.subplot(n, n, i + 1)
        plt.imshow(images[i, 0], cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
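If you would rather write the samples to disk than display them, torchvision ships a grid utility. This helper is an optional alternative to show_images; the function name and file path are just illustrative choices:

from torchvision.utils import save_image

def save_samples(images, path, nrow=5):
    """Save a grid of generated images to disk (alternative to show_images)."""
    # normalize=True rescales the Tanh output range [-1, 1] back to [0, 1]
    save_image(images, path, nrow=nrow, normalize=True)

# Example: save_samples(gen_imgs, 'samples_epoch_10.png')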
Training the GAN
Now let's put everything together and train our GAN:
# Training parameters
num_epochs = 50
sample_interval = 10 # Generate sample images every 'sample_interval' epochs
samples_to_generate = 25
# Create fixed noise for visualization
fixed_noise = torch.randn(samples_to_generate, latent_dim, device=device)
# Training loop
for epoch in range(num_epochs):
    for batch_idx, (real_imgs, _) in enumerate(train_loader):
        # Move data to device
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # Labels for real and fake images
        real_label = torch.ones(batch_size, 1, device=device)
        fake_label = torch.zeros(batch_size, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Loss on real images
        real_pred = discriminator(real_imgs)
        d_loss_real = adversarial_loss(real_pred, real_label)

        # Loss on fake images
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = generator(z)
        fake_pred = discriminator(fake_imgs.detach())
        d_loss_fake = adversarial_loss(fake_pred, fake_label)

        # Total discriminator loss
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        #  Train Generator
        # ---------------------
        optimizer_G.zero_grad()

        # Generate new fake images
        fake_imgs = generator(z)
        fake_pred = discriminator(fake_imgs)

        # Try to fool the discriminator
        g_loss = adversarial_loss(fake_pred, real_label)
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if batch_idx % 100 == 0:
            print(
                f"[Epoch {epoch}/{num_epochs}] "
                f"[Batch {batch_idx}/{len(train_loader)}] "
                f"[D loss: {d_loss.item():.4f}] "
                f"[G loss: {g_loss.item():.4f}]"
            )

    # Generate and show example images
    if (epoch + 1) % sample_interval == 0:
        print(f"Generating sample images for epoch {epoch+1}...")
        generator.eval()
        with torch.no_grad():
            gen_imgs = generator(fixed_noise)
        show_images(gen_imgs, samples_to_generate)
        generator.train()
When you run this code, you'll see the training progress printed every 100 batches and sample images displayed every sample_interval epochs. Over time, the generated images should look more and more like real handwritten digits.
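Once training finishes, you'll usually want to persist the generator so you can sample digits later without retraining. A minimal sketch, assuming the file names below are your own choice:

# Save the trained weights (file names are arbitrary)
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

# Later, to generate new digits without retraining:
generator.load_state_dict(torch.load('generator.pth', map_location=device))
generator.eval()
with torch.no_grad():
    samples = generator(torch.randn(25, latent_dim, device=device))
show_images(samples)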
Expected Results
After training for 50 epochs, you should see generated images that resemble handwritten digits from the MNIST dataset. Here's what you might expect from the generated output at different stages:
- Early epochs (1-10): Blurry, barely recognizable shapes
- Middle epochs (11-30): More defined digit-like shapes starting to appear
- Later epochs (31-50): Fairly clear digits that resemble the MNIST dataset
Common Issues and Tips for Training GANs
GANs can be notoriously difficult to train. Here are some tips if you encounter problems:
- Mode collapse: If your generator produces only a few distinct kinds of images, try:
  - Adding noise to the discriminator inputs
  - Using a different loss function, such as the Wasserstein loss
- Unstable training: If the losses fluctuate wildly, try the following (a code sketch follows this list):
  - Reduce the learning rate
  - Apply gradient clipping
  - Add batch normalization
- Poor-quality images:
  - Train for more epochs
  - Use a deeper architecture
  - Try different activation functions
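As an illustration of the stabilization tips above, here is how one-sided label smoothing and gradient clipping could be dropped into the discriminator step of the training loop. The 0.9 target and the max_norm value are illustrative choices, not part of the original code:

# Inside the discriminator step of the training loop above:

# One-sided label smoothing: use 0.9 instead of 1.0 as the "real" target
real_label = torch.full((batch_size, 1), 0.9, device=device)

# ... compute d_loss exactly as before ...
d_loss.backward()
# Clip gradients before the optimizer step to damp sudden loss spikes
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
optimizer_D.step()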
Practical Applications of GANs
GANs have numerous real-world applications:
- Image generation: Creating realistic images, artwork, or designs
- Image-to-image translation: Converting sketches to photos, day to night scenes
- Super-resolution: Enhancing low-resolution images
- Data augmentation: Generating additional training data for other models
- Drug discovery: Generating molecular structures
- Fashion and design: Creating new clothing or product designs
Example: Generating a Specific Digit
Let's extend our GAN to generate a specific digit. We'll use a conditional GAN (cGAN), which conditions generation on a class label. Here is a conditional version of the generator:
class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes=10):
        super(ConditionalGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.init_size = 7
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128 * self.init_size * self.init_size)
        )
        self.model = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # Embed the class label and concatenate it with the noise vector
        label_embedding = self.label_embedding(labels)
        x = torch.cat([z, label_embedding], dim=1)
        out = self.l1(x)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.model(out)
        return img
# Example usage:
# cond_generator = ConditionalGenerator(latent_dim)
# z = torch.randn(batch_size, latent_dim, device=device)
# labels = torch.randint(0, 10, (batch_size,), device=device)
# fake_imgs = cond_generator(z, labels)
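To actually train a cGAN, the discriminator also has to see the label; otherwise the generator has no incentive to respect it. One common design, shown below as a sketch rather than the tutorial's own code, embeds the label into a 28x28 map and stacks it with the image as a second input channel:

class ConditionalDiscriminator(nn.Module):
    def __init__(self, num_classes=10):
        super(ConditionalDiscriminator, self).__init__()
        # Embed the label into a 28x28 map that is stacked with the image
        self.label_embedding = nn.Embedding(num_classes, 28 * 28)
        self.model = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=4, stride=2, padding=1),  # image + label channel
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        label_map = self.label_embedding(labels).view(-1, 1, 28, 28)
        x = torch.cat([img, label_map], dim=1)
        return self.model(x)

Training then proceeds exactly as in the loop above, except that both networks receive the labels: the real labels from the data loader for real images, and randomly sampled labels for the fakes.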
Summary
In this tutorial, you've learned:
- What GANs are: Two neural networks (generator and discriminator) that work in an adversarial setting
- How to implement a basic GAN in PyTorch: Building and training the generator and discriminator
- Tips for training GANs: Common issues and their solutions
- Practical applications: Various real-world uses of GANs
- Advanced concepts: Introduction to conditional GANs
GANs represent a powerful approach to generative modeling and have revolutionized the field of artificial intelligence, particularly in image synthesis. With the knowledge gained from this tutorial, you can now experiment with different GAN architectures and applications.
Additional Resources and Exercises
Further Learning Resources
- Original GAN Paper by Ian Goodfellow
- DCGAN Paper: Unsupervised Representation Learning with DCGANs
- Conditional GANs Paper
- StyleGAN: A Style-Based Generator Architecture for GANs
Exercises
- Modify the architecture: Try changing the network architecture to see how it affects the results.
- Generate other datasets: Adapt the code to work with other datasets like CIFAR-10 or Fashion-MNIST.
- Implement a DCGAN: Refactor the code to follow the Deep Convolutional GAN (DCGAN) architecture.
- Create a conditional GAN: Extend the GAN to generate images based on class labels.
- Implement different loss functions: Try using Wasserstein loss or least squares loss instead of binary cross-entropy.
By completing these exercises, you'll gain a deeper understanding of GANs and improve your PyTorch skills.