
PyTorch GANs

Introduction to Generative Adversarial Networks

Generative Adversarial Networks (GANs) represent one of the most exciting advancements in deep learning in recent years. Introduced by Ian Goodfellow and colleagues in 2014, GANs consist of two neural networks, a generator and a discriminator, that are trained simultaneously in a competitive setting.

  • The generator learns to create fake data that resembles real data
  • The discriminator learns to distinguish between real and fake data

This adversarial process continues until the generator produces data that the discriminator cannot differentiate from real data.
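The original paper writes this competition as a two-player minimax game. Here D(x) is the discriminator's estimated probability that x is real, G(z) is the generator's output for noise z drawn from a prior p_z, and p_g is the distribution of generated samples. The value function, and the optimal discriminator for a fixed generator, are:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}
```

When p_g equals p_data, D* outputs 1/2 for every input. That is the precise sense in which the discriminator "cannot differentiate" generated data from real data.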

In this tutorial, we'll learn how to implement a basic GAN using PyTorch to generate handwritten digits similar to those in the MNIST dataset.

Understanding the GAN Architecture

Before diving into code, let's understand the core components and workflow of a GAN:

  1. Generator: Takes random noise as input and outputs synthetic data
  2. Discriminator: Takes either real data or generated data and outputs a probability indicating whether the input is real or fake
  3. Training Process: Alternates between training the discriminator and the generator

[Figure: GAN architecture diagram]

Prerequisites

To follow along with this tutorial, you'll need Python with the following packages installed:

```bash
pip install torch torchvision numpy matplotlib
```

Setting Up Our Environment

Let's start by importing the necessary libraries:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
```

Loading the MNIST Dataset

We'll use the MNIST dataset of handwritten digits:

```python
# Define image transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

# Load MNIST dataset
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# Create data loader
batch_size = 64
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)
```
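`Normalize(mean=[0.5], std=[0.5])` computes `(x - 0.5) / 0.5` for each pixel, so the [0, 1] range produced by `ToTensor` becomes [-1, 1], matching the generator's `Tanh` output. The plain-Python sketch below (no PyTorch needed) checks that mapping and its inverse, and counts the mini-batches per epoch. It assumes the standard 60,000-image MNIST training split and the default `drop_last=False` of `DataLoader`.

```python
import math

def normalize(x, mean=0.5, std=0.5):
    """Per-pixel transform applied by transforms.Normalize."""
    return (x - mean) / std

def denormalize(y):
    """Inverse mapping from [-1, 1] back to [0, 1], as used for display."""
    return (y + 1) / 2.0

pixels = [0.0, 0.25, 0.5, 1.0]
scaled = [normalize(p) for p in pixels]
print(scaled)  # [-1.0, -0.5, 0.0, 1.0]

# Round-trip back to the original pixel values
restored = [denormalize(s) for s in scaled]

# 60,000 training images in batches of 64; the final batch is partial
num_train, batch_size = 60_000, 64
batches_per_epoch = math.ceil(num_train / batch_size)
last_batch = num_train - (batches_per_epoch - 1) * batch_size
print(batches_per_epoch, last_batch)  # 938 32
```

This is why the progress line in the training loop reports 938 batches per epoch, and why the loop reads the batch size from `real_imgs.size(0)` instead of assuming 64.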

Building the Discriminator

The discriminator is a binary classifier that tries to distinguish between real and generated images:

```python
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # A simple network with convolutional layers
        self.model = nn.Sequential(
            # Input is 1x28x28
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x14x14
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x7x7
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128x3x3
            nn.Flatten(),
            nn.Linear(128 * 3 * 3, 1),
            nn.Sigmoid()  # Output is probability between 0 and 1
        )

    def forward(self, x):
        return self.model(x)
```
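The `128 * 3 * 3` input size of the final `nn.Linear` follows from the convolution output-size formula, `floor((H + 2*padding - kernel_size) / stride) + 1` (with the default dilation of 1). The short plain-Python sketch below copies the layer settings from the `Discriminator` above and traces the spatial size through each `nn.Conv2d`.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel_size) // stride + 1

# (out_channels, kernel_size, stride, padding) for each Conv2d in the Discriminator
layers = [(32, 4, 2, 1), (64, 4, 2, 1), (128, 3, 2, 0)]

size, sizes = 28, [28]
for channels, k, s, p in layers:
    size = conv_out(size, k, s, p)
    sizes.append(size)

# Features seen by nn.Flatten -> nn.Linear
flat_features = channels * size * size
print(sizes, flat_features)  # [28, 14, 7, 3] 1152
```

If you change a kernel size, stride, or padding, rerun the trace and update the `nn.Linear` input size to match; otherwise the forward pass fails with a shape-mismatch error.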

Building the Generator

The generator creates images from random noise:

```python
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()

        self.latent_dim = latent_dim

        # Initial size based on upsampling from random noise
        self.init_size = 7  # Will be upsampled to 28x28
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.init_size * self.init_size)
        )

        # Upsample to 28x28 through convolutional layers
        self.model = nn.Sequential(
            nn.BatchNorm2d(128),
            # 128x7x7
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128x14x14
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x28x28
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # Output normalized between -1 and 1
        )

    def forward(self, z):
        # Project and reshape the noise
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)

        # Generate the image
        img = self.model(out)
        return img
```
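The generator runs the same arithmetic in the other direction. `nn.Linear` projects the noise to `128 * 7 * 7` values, `view` reshapes them into a 128x7x7 feature map, each `nn.Upsample(scale_factor=2)` doubles the height and width, and every 3x3 convolution with `stride=1, padding=1` preserves them. Here is a plain-Python trace of those sizes, with the settings copied from the `Generator` class and `latent_dim = 100` as used later in this tutorial:

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel_size) // stride + 1

latent_dim, init_size, channels = 100, 7, 128

# nn.Linear(latent_dim, 128 * 7 * 7): output features, then weights plus biases
linear_out = channels * init_size * init_size
linear_params = latent_dim * linear_out + linear_out

size, sizes = init_size, [init_size]
for op in ["upsample", "conv", "upsample", "conv", "conv"]:
    if op == "upsample":
        size *= 2  # nn.Upsample(scale_factor=2)
    else:
        size = conv_out(size, kernel_size=3, stride=1, padding=1)
    sizes.append(size)

print(linear_out, linear_params)  # 6272 633472
print(sizes)  # [7, 14, 14, 28, 28, 28]
```

The final size of 28 matches MNIST, and the last convolution reduces 64 channels to 1, so generated batches have the same `(batch, 1, 28, 28)` shape as the real images the discriminator sees.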

Initializing Models and Optimizers

Now, let's initialize our models and optimizers:

```python
# Hyperparameters
latent_dim = 100  # Dimension of the random noise
lr = 0.0002
beta1 = 0.5
beta2 = 0.999

# Initialize models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Binary cross entropy loss
adversarial_loss = nn.BCELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))
```
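With its default `reduction='mean'`, `nn.BCELoss` averages `-(y * log(p) + (1 - y) * log(1 - p))` over all elements, where `p` is the discriminator's output and `y` is the target label. The plain-Python sketch below mirrors that formula so you can see what each label costs. It omits one detail of the PyTorch version, which clamps each log term at -100 to avoid infinities.

```python
import math

def bce(preds, targets):
    """Mean binary cross-entropy, the formula nn.BCELoss uses with
    reduction='mean' (PyTorch's clamping of the log terms is omitted)."""
    terms = [
        -(y * math.log(p) + (1 - y) * math.log(1 - p))
        for p, y in zip(preds, targets)
    ]
    return sum(terms) / len(terms)

# Real image (target 1) that D scores 0.9: confident and correct, low loss
print(round(bce([0.9], [1.0]), 4))  # 0.1054
# Fake image (target 0) that D scores 0.9: confident and wrong, high loss
print(round(bce([0.9], [0.0]), 4))  # 2.3026
```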

Utility Functions for Visualization

Let's create a function to visualize our generated images:

```python
def show_images(images, num_images=25):
    """Display a batch of images in a grid."""
    # Convert tensor images to numpy images
    images = images.detach().cpu().numpy()
    # Rescale from [-1, 1] to [0, 1]
    images = (images + 1) / 2.0

    # Create a square grid; round up so non-square counts still fit
    n = int(np.ceil(np.sqrt(num_images)))
    plt.figure(figsize=(8, 8))
    for i in range(num_images):
        plt.subplot(n, n, i + 1)
        plt.imshow(images[i, 0], cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
```
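A square grid needs a side length `n` with `n * n >= num_images`. Truncating the square root is exact for 25 images but gives 3 for 10 images, and `plt.subplot(3, 3, 10)` raises a `ValueError`. Rounding up, as in this small stand-alone helper (the name `grid_side` is illustrative only), always leaves enough cells:

```python
import math

def grid_side(num_images):
    """Smallest square-grid side that fits num_images subplots."""
    return math.ceil(math.sqrt(num_images))

for count in (25, 16, 10, 30):
    side = grid_side(count)
    print(count, side, side * side >= count)
# 25 5 True
# 16 4 True
# 10 4 True
# 30 6 True
```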

Training the GAN

Now let's put everything together and train our GAN:

```python
# Training parameters
num_epochs = 50
sample_interval = 10  # Generate sample images every 'sample_interval' epochs
samples_to_generate = 25

# Create fixed noise for visualization
fixed_noise = torch.randn(samples_to_generate, latent_dim, device=device)

# Training loop
for epoch in range(num_epochs):
    for batch_idx, (real_imgs, _) in enumerate(train_loader):
        # Move data to device
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)  # The last batch can be smaller

        # Labels for real and fake images
        real_label = torch.ones(batch_size, 1, device=device)
        fake_label = torch.zeros(batch_size, 1, device=device)

        # ---------------------
        # Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Loss on real images
        real_pred = discriminator(real_imgs)
        d_loss_real = adversarial_loss(real_pred, real_label)

        # Loss on fake images (detach so this step does not backpropagate into G)
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = generator(z)
        fake_pred = discriminator(fake_imgs.detach())
        d_loss_fake = adversarial_loss(fake_pred, fake_label)

        # Total discriminator loss
        d_loss = (d_loss_real + d_loss_fake) / 2

        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        # Train Generator
        # ---------------------
        optimizer_G.zero_grad()

        # Re-run the generator on the same noise, this time without detaching,
        # so gradients flow through the updated discriminator back into G
        fake_imgs = generator(z)
        fake_pred = discriminator(fake_imgs)

        # Try to fool the discriminator: score the fakes against the "real" label
        g_loss = adversarial_loss(fake_pred, real_label)

        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if batch_idx % 100 == 0:
            print(
                f"[Epoch {epoch}/{num_epochs}] "
                f"[Batch {batch_idx}/{len(train_loader)}] "
                f"[D loss: {d_loss.item():.4f}] "
                f"[G loss: {g_loss.item():.4f}]"
            )

    # Generate and show example images at the end of every sample_interval epochs
    if (epoch + 1) % sample_interval == 0:
        print(f"Generating sample images for epoch {epoch+1}...")
        generator.eval()
        with torch.no_grad():
            gen_imgs = generator(fixed_noise)
        show_images(gen_imgs, samples_to_generate)
        generator.train()
```

When you run this code, you'll see the progress of the training and sample images displayed every sample_interval epochs. Over time, the generated images should become more and more like real handwritten digits.
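Two reference points help when reading the printed losses. First, if the discriminator outputs 0.5 for every image, both `d_loss` (the mean of its real and fake terms) and `g_loss` equal log 2 ≈ 0.693. Second, `adversarial_loss(fake_pred, real_label)` is `-log D(G(z))`, the "non-saturating" generator loss that the original GAN paper suggests instead of minimizing `log(1 - D(G(z)))` directly. The plain-Python sketch below computes both reference values and compares the two generator objectives' gradients with respect to p = D(G(z)) when the discriminator confidently rejects fakes, as can happen early in training:

```python
import math

def d_loss_at(p_real, p_fake):
    """Discriminator loss as computed in the loop: the mean of the BCE on a
    real image (target 1) and on a fake image (target 0)."""
    return (-math.log(p_real) - math.log(1 - p_fake)) / 2

def g_loss_at(p_fake):
    """Non-saturating generator loss: BCE of D(G(z)) against target 1."""
    return -math.log(p_fake)

# Reference point: D outputs 0.5 for real and fake alike
print(round(d_loss_at(0.5, 0.5), 4), round(g_loss_at(0.5), 4))  # 0.6931 0.6931

# Gradients w.r.t. p = D(G(z)) when D rejects fakes confidently
p = 0.01
grad_saturating = -1 / (1 - p)   # d/dp of log(1 - p)
grad_non_saturating = -1 / p     # d/dp of -log(p)
print(round(abs(grad_saturating), 2), round(abs(grad_non_saturating), 2))  # 1.01 100.0
```

The saturating form gives the generator a gradient of roughly 1 exactly when its samples are worst, while the non-saturating form grows as 1/p. In practice the losses rarely settle at 0.693, so treat that value as a rough landmark rather than a convergence test.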

Expected Results

After training for 50 epochs, you should see generated images that resemble handwritten digits from the MNIST dataset. Here's what you might expect from the generated output at different stages:

  • Early epochs (1-10): Blurry, barely recognizable shapes
  • Middle epochs (11-30): More defined digit-like shapes starting to appear
  • Later epochs (31-50): Fairly clear digits that resemble the MNIST dataset

Common Issues and Tips for Training GANs

GANs can be notoriously difficult to train. Here are some tips if you encounter problems:

  1. Mode collapse: If your generator produces only a few kinds of images (for example, the same digit over and over), try:

    • Adding noise to the discriminator inputs
    • Using different loss functions like Wasserstein loss
  2. Unstable training: If losses fluctuate wildly:

    • Reduce learning rate
    • Implement gradient clipping
    • Add batch normalization
  3. Poor quality images:

    • Train for more epochs
    • Use a deeper architecture
    • Try different activation functions
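The "add noise to the discriminator inputs" and "gradient clipping" tips each take only a line or two in the discriminator update. Below is a minimal sketch that uses a hypothetical one-layer discriminator (`disc`) and a random stand-in batch so it runs on its own. The values of `noise_std` and `max_norm` are illustrative assumptions, not tuned recommendations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny discriminator, only to show where the calls go
disc = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1), nn.Sigmoid())
opt = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
loss_fn = nn.BCELoss()

real_imgs = torch.rand(8, 1, 28, 28) * 2 - 1  # Stand-in batch in [-1, 1]
real_label = torch.ones(8, 1)

noise_std = 0.1  # Assumed value; often annealed toward 0 during training
max_norm = 1.0   # Assumed clipping threshold

opt.zero_grad()
# Instance noise: perturb the discriminator's inputs. In a real loop, add
# the same kind of noise to fake_imgs.detach() as well.
noisy_real = real_imgs + noise_std * torch.randn_like(real_imgs)
loss = loss_fn(disc(noisy_real), real_label)
loss.backward()

# Gradient clipping: rescale gradients after backward() and before step().
# clip_grad_norm_ returns the total norm measured before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(disc.parameters(), max_norm)
clipped_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in disc.parameters()))
opt.step()

print(f"grad norm before: {total_norm.item():.3f}, after: {clipped_norm.item():.3f}")
```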

Practical Applications of GANs

GANs have numerous real-world applications:

  1. Image generation: Creating realistic images, artwork, or designs
  2. Image-to-image translation: Converting sketches to photos, day to night scenes
  3. Super-resolution: Enhancing low-resolution images
  4. Data augmentation: Generating additional training data for other models
  5. Drug discovery: Generating molecular structures
  6. Fashion and design: Creating new clothing or product designs

Example: Generating a Specific Digit

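Let's extend our GAN to generate a specific digit using a conditional GAN (cGAN) approach. In a cGAN, the class label is fed to both networks: the generator receives it alongside the noise, and the discriminator receives it alongside the image so it can judge whether the image matches its label. Only the generator is shown below; training it also requires a label-conditioned discriminator.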

```python
class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes=10):
        super(ConditionalGenerator, self).__init__()

        self.latent_dim = latent_dim
        # Learn a num_classes-dimensional vector for each digit label
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        self.init_size = 7
        self.l1 = nn.Sequential(
            # Noise and label embedding are concatenated before the projection
            nn.Linear(latent_dim + num_classes, 128 * self.init_size * self.init_size)
        )

        self.model = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z, labels):
        label_embedding = self.label_embedding(labels)
        x = torch.cat([z, label_embedding], dim=1)

        out = self.l1(x)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)

        img = self.model(out)
        return img

# Example usage (move the model to the same device as its inputs):
# cond_generator = ConditionalGenerator(latent_dim).to(device)
# z = torch.randn(batch_size, latent_dim, device=device)
# labels = torch.randint(0, 10, (batch_size,), device=device)
# fake_imgs = cond_generator(z, labels)
```
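The commented usage above draws random labels. To generate one particular digit, fill the label tensor with that class instead. The sketch below is self-contained, so it rebuilds only the conditioning step of `ConditionalGenerator` (embed the label, concatenate it with the noise) rather than the whole network; with a trained `ConditionalGenerator` you would pass the same `z` and `labels` to it. The choices `digit = 7` and `n = 16` are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, num_classes, n = 100, 10, 16
digit = 7  # Every sample should show this class

# Same conditioning scheme as ConditionalGenerator: embed, then concatenate
label_embedding = nn.Embedding(num_classes, num_classes)
z = torch.randn(n, latent_dim)
labels = torch.full((n,), digit, dtype=torch.long)
x = torch.cat([z, label_embedding(labels)], dim=1)

print(x.shape)  # torch.Size([16, 110])
```

Because every row carries the same label embedding, only the noise part of the input varies, so a trained conditional generator should produce 16 different-looking 7s.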

Summary

In this tutorial, you've learned:

  1. What GANs are: Two neural networks (generator and discriminator) that work in an adversarial setting
  2. How to implement a basic GAN in PyTorch: Building and training the generator and discriminator
  3. Tips for training GANs: Common issues and their solutions
  4. Practical applications: Various real-world uses of GANs
  5. Advanced concepts: Introduction to conditional GANs

GANs are a powerful approach to generative modeling and have had a major impact on artificial intelligence, particularly in image synthesis. With the knowledge gained from this tutorial, you can now experiment with different GAN architectures and applications.

Additional Resources and Exercises

Further Learning Resources

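  1. Generative Adversarial Networks (Goodfellow et al., 2014), the original GAN paper
  2. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks (Radford et al., 2015), the DCGAN paper
  3. Conditional Generative Adversarial Nets (Mirza and Osindero, 2014)
  4. A Style-Based Generator Architecture for Generative Adversarial Networks (Karras et al., 2019), the StyleGAN paper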

Exercises

  1. Modify the architecture: Try changing the network architecture to see how it affects the results.
  2. Train on other datasets: Adapt the code to work with datasets like Fashion-MNIST or CIFAR-10.
  3. Implement a DCGAN: Refactor the code to follow the Deep Convolutional GAN (DCGAN) architecture.
  4. Complete the conditional GAN: Add a discriminator that also receives the class label, then train it together with the ConditionalGenerator from this tutorial.
  5. Implement different loss functions: Try using Wasserstein loss or least squares loss instead of binary cross-entropy.

By completing these exercises, you'll gain a deeper understanding of GANs and improve your PyTorch skills.


