
PyTorch Autoencoders

Autoencoders are a special type of neural network architecture designed to learn efficient data representations in an unsupervised manner. In this tutorial, you'll learn how autoencoders work, how to implement them in PyTorch, and explore some of their real-world applications.

What Are Autoencoders?

An autoencoder is a neural network that learns to copy its input to its output. This might seem trivial, but there's a catch: the network is designed with a bottleneck (a narrow layer) in the middle, forcing it to learn a compressed representation of the input data.

Autoencoders consist of two main parts:

  • An encoder that compresses the input into a latent-space representation
  • A decoder that reconstructs the input from the latent-space representation
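To make the two parts concrete, here is a minimal sketch of an encoder/decoder pair for flattened 28×28 images. The single-layer design and the 784 → 16 → 784 sizes are illustrative assumptions. The point is how the data's shape changes at each stage and what compression ratio the bottleneck implies.

```python
import torch
import torch.nn as nn

# Illustrative sizes: a flattened 28x28 image and a 16-dimensional bottleneck
input_dim, latent_dim = 28 * 28, 16

# Encoder: maps each input vector to a short latent code
encoder = nn.Sequential(nn.Linear(input_dim, latent_dim), nn.ReLU())
# Decoder: maps the latent code back to the original input size
decoder = nn.Sequential(nn.Linear(latent_dim, input_dim), nn.Sigmoid())

x = torch.rand(8, input_dim)   # a batch of 8 fake "images" with values in [0, 1]
z = encoder(x)                 # compressed representation, shape (8, 16)
x_hat = decoder(z)             # reconstruction, shape (8, 784)

print(z.shape, x_hat.shape)    # torch.Size([8, 16]) torch.Size([8, 784])
print(f"compression ratio: {input_dim / latent_dim:.0f}x")  # compression ratio: 49x
```

Training adjusts both halves together so that `x_hat` stays close to `x` even though everything must pass through the 16 numbers in `z`.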

Basic Autoencoder Implementation

Let's build a simple autoencoder for the MNIST dataset (handwritten digits).

First, import the necessary libraries:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
```

1. Preparing the Data

```python
# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Download and load the MNIST dataset
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
```

2. Building the Autoencoder Model

```python
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16)  # Latent space dimension
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid()  # Output values between 0 and 1
        )

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten the input
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded.view(-1, 1, 28, 28)  # Reshape back to image format
```

3. Training the Autoencoder

```python
# Initialize the model, loss function, and optimizer
model = Autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10
outputs = []  # (epoch, last input batch, its reconstruction) for each epoch

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for img, _ in train_loader:  # labels are not needed for reconstruction
        # Forward pass: the target is the input itself
        output = model(img)
        loss = criterion(output, img)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print the average loss for this epoch
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

    # Save the last batch and its reconstruction for visualization
    # (detached so the stored tensors do not keep the autograd graph alive)
    outputs.append((epoch, img, output.detach()))

print("Training complete!")
```
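The loop above reports only the training loss. A minimal sketch for measuring reconstruction loss on held-out data follows. `evaluate` is a hypothetical helper name, and the tiny model and random tensors in the usage lines only stand in for the trained `Autoencoder` and `test_loader` so the snippet runs on its own.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, data_loader, criterion):
    """Average per-batch reconstruction loss over a data loader (hypothetical helper)."""
    model.eval()                      # switch off training-only behavior
    total_loss = 0.0
    with torch.no_grad():             # no gradients are needed for evaluation
        for img, _ in data_loader:
            output = model(img)
            total_loss += criterion(output, img).item()
    model.train()                     # return to training mode for further epochs
    return total_loss / len(data_loader)

# Stand-ins: random "images" and a tiny model that keeps the (B, 1, 28, 28) shape
fake_images = torch.rand(64, 1, 28, 28)
fake_labels = torch.zeros(64, dtype=torch.long)
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=16)

tiny_model = nn.Sequential(
    nn.Flatten(),                     # (B, 1, 28, 28) -> (B, 784)
    nn.Linear(28 * 28, 16), nn.ReLU(),
    nn.Linear(16, 28 * 28), nn.Sigmoid(),
    nn.Unflatten(1, (1, 28, 28)),     # (B, 784) -> (B, 1, 28, 28)
)
test_loss = evaluate(tiny_model, loader, nn.MSELoss())
print(f"Test loss: {test_loss:.4f}")
```

With the tutorial's objects, the equivalent call would be `evaluate(model, test_loader, criterion)`. A test loss much higher than the training loss suggests overfitting.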

4. Visualizing the Results

```python
def plot_reconstruction(epoch, original, reconstructed):
    plt.figure(figsize=(10, 4))

    # Plot original images (top row)
    for i in range(5):
        plt.subplot(2, 5, i + 1)
        plt.imshow(original[i][0].detach().numpy(), cmap='gray')
        plt.title("Original")
        plt.axis('off')

    # Plot reconstructed images (bottom row)
    for i in range(5):
        plt.subplot(2, 5, i + 6)
        plt.imshow(reconstructed[i][0].detach().numpy(), cmap='gray')
        plt.title("Reconstructed")
        plt.axis('off')

    plt.suptitle(f"Epoch {epoch+1}")
    plt.tight_layout()
    plt.show()

# Visualize the last epoch result
# (new names so the `outputs` list from training is not overwritten)
last_epoch, imgs, recons = outputs[-1]
plot_reconstruction(last_epoch, imgs, recons)
```

Expected output:

  • A plot showing original and reconstructed MNIST digits side by side
  • The reconstructed images should look similar to the originals, demonstrating that the autoencoder has learned an efficient representation

Convolutional Autoencoder

For image data, convolutional autoencoders often perform better than fully connected ones. Let's implement a convolutional autoencoder:

```python
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()

        # Encoder: 1x28x28 -> 16x14x14 -> 8x7x7
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Decoder: 8x7x7 -> 16x14x14 -> 1x28x28
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

# Initialize and train the model (similar to the previous example)
conv_model = ConvAutoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(conv_model.parameters(), lr=1e-3)

# Training loop would be similar to the previous example
```
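Because the loop is the same for every model here, it can be wrapped in a reusable function. This is a minimal sketch that assumes the model returns a tensor shaped like its input, which holds for `ConvAutoencoder` (1×28×28 → 8×7×7 → 1×28×28). `train_autoencoder` is a hypothetical name, and the model and random data in the usage lines are stand-ins that reproduce `ConvAutoencoder`'s layers so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def train_autoencoder(model, data_loader, num_epochs=10, lr=1e-3):
    """Train any autoencoder whose output has the same shape as its input."""
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = []
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for img, _ in data_loader:
            output = model(img)
            loss = criterion(output, img)   # the target is the input itself
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(data_loader)
        history.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
    return history

# Stand-in with the same layers as ConvAutoencoder, written as one Sequential
conv_stand_in = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),      # -> 16x14x14
    nn.Conv2d(16, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),      # -> 8x7x7
    nn.ConvTranspose2d(8, 16, 3, stride=2, padding=1, output_padding=1), nn.ReLU(),   # -> 16x14x14
    nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1), nn.Sigmoid(),  # -> 1x28x28
)
loader = DataLoader(TensorDataset(torch.rand(32, 1, 28, 28), torch.zeros(32)), batch_size=8)
history = train_autoencoder(conv_stand_in, loader, num_epochs=2)
```

With the tutorial's objects, this would be `train_autoencoder(conv_model, train_loader)`. Note that the function creates its own Adam optimizer, so an optimizer built outside it is not used.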

Variational Autoencoder (VAE)

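Variational autoencoders (VAEs) make the latent space probabilistic. Instead of a single latent vector, the encoder outputs the mean and log-variance of a Gaussian distribution, and a KL-divergence term in the loss pulls that distribution toward a standard normal prior. Because the latent space is regularized this way, you can sample from the prior and decode the samples to generate new data.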

```python
class VAE(nn.Module):
    def __init__(self, latent_dim=16):
        super(VAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )

        # Layers producing the mean and the log-variance of the latent Gaussian
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_var = nn.Linear(128, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid()
        )

    def encode(self, x):
        x = x.view(-1, 28 * 28)
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so gradients can flow back through mu and log_var
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        output = self.decoder(z)
        return output.view(-1, 1, 28, 28)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        decoded = self.decode(z)
        return decoded, mu, log_var

# Modified loss function for VAE
def vae_loss_function(recon_x, x, mu, log_var):
    # Reconstruction loss (binary cross-entropy, summed over pixels and batch)
    BCE = nn.functional.binary_cross_entropy(
        recon_x.view(-1, 28 * 28),
        x.view(-1, 28 * 28),
        reduction='sum'
    )

    # KL divergence between N(mu, sigma^2) and the standard normal prior
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return BCE + KLD
```
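The `KLD` term is the closed-form KL divergence between the encoder's diagonal Gaussian N(μ, σ²), with σ² = exp(log_var), and the standard normal prior. The sketch below checks that formula numerically against `torch.distributions.kl_divergence`, then shows how new images are generated by decoding samples drawn from the prior. The random `mu`/`log_var` values and the small untrained `decoder` are stand-ins. With a trained `VAE`, you would call `vae.decode(z)` instead.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
latent_dim = 16

# Stand-in encoder outputs for a batch of 4 inputs
mu = torch.randn(4, latent_dim)
log_var = torch.randn(4, latent_dim)

# Closed-form KL term, exactly as written in vae_loss_function
kld_closed_form = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# The same quantity via torch.distributions (std = exp(0.5 * log_var))
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_reference = kl_divergence(q, p).sum()
print(kld_closed_form.item(), kld_reference.item())  # the two values agree

# Generation: sample z from the N(0, I) prior and decode it into images
decoder = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Sigmoid())  # untrained stand-in
with torch.no_grad():
    z = torch.randn(8, latent_dim)
    samples = decoder(z).view(-1, 1, 28, 28)
```

Sampling from N(0, I) is meaningful only because the KL term trained the encoder's distributions to stay close to that prior. Decoding random vectors from a plain autoencoder's latent space usually gives much less coherent images.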

Denoising Autoencoder

Denoising autoencoders learn to remove noise from corrupted inputs, making them useful for image restoration and enhancement.

```python
class DenoisingAutoencoder(nn.Module):
    def __init__(self):
        super(DenoisingAutoencoder, self).__init__()
        # Same architecture as our first autoencoder
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )

        self.decoder = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded.view(-1, 1, 28, 28)

# Function to add noise to images
def add_noise(images, noise_factor=0.5):
    noisy_images = images + noise_factor * torch.randn_like(images)
    return torch.clamp(noisy_images, 0., 1.)

# Training loop for denoising autoencoder would include:
# 1. Adding noise to original images
# 2. Feeding noisy images to the autoencoder
# 3. Computing loss between original (not noisy) images and reconstructions
```
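The three steps can be sketched as a single training function. This is an illustration, not the tutorial's own loop. `train_denoising` is a hypothetical name, `add_noise` is repeated so the snippet is self-contained, and the tiny model and random data stand in for `DenoisingAutoencoder` and `train_loader`.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def add_noise(images, noise_factor=0.5):
    noisy_images = images + noise_factor * torch.randn_like(images)
    return torch.clamp(noisy_images, 0., 1.)

def train_denoising(model, data_loader, num_epochs=10, noise_factor=0.5, lr=1e-3):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for clean, _ in data_loader:
            noisy = add_noise(clean, noise_factor)   # 1. corrupt the input
            output = model(noisy)                    # 2. reconstruct from the noisy version
            loss = criterion(output, clean)          # 3. compare against the CLEAN images
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(data_loader):.4f}")
    return model

# Stand-ins so the snippet runs on its own
tiny_model = nn.Sequential(
    nn.Flatten(), nn.Linear(28 * 28, 32), nn.ReLU(),
    nn.Linear(32, 28 * 28), nn.Sigmoid(), nn.Unflatten(1, (1, 28, 28)),
)
loader = DataLoader(TensorDataset(torch.rand(32, 1, 28, 28), torch.zeros(32)), batch_size=8)
trained = train_denoising(tiny_model, loader, num_epochs=1)
```

After training the real `DenoisingAutoencoder` this way, `torch.save(model.state_dict(), "denoising_autoencoder.pth")` writes the weights file that the application example loads.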

Real-world Applications of Autoencoders

Autoencoders have numerous practical applications:

1. Image Compression and Denoising

```python
# Example: Using a trained denoising autoencoder to clean a noisy image
model = DenoisingAutoencoder()
# Load the trained model weights (assumes they were saved earlier with
# torch.save(model.state_dict(), "denoising_autoencoder.pth"))
model.load_state_dict(torch.load("denoising_autoencoder.pth"))
model.eval()

# Load a test image and add Gaussian noise to it
noisy_image = torch.randn(1, 1, 28, 28) * 0.3 + test_dataset[0][0]
noisy_image = torch.clamp(noisy_image, 0, 1)

# Denoise the image
with torch.no_grad():
    cleaned_image = model(noisy_image)

# Plot the results
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.title("Original Image")
plt.imshow(test_dataset[0][0].squeeze(), cmap='gray')
plt.subplot(1, 3, 2)
plt.title("Noisy Image")
plt.imshow(noisy_image.squeeze(), cmap='gray')
plt.subplot(1, 3, 3)
plt.title("Cleaned Image")
plt.imshow(cleaned_image.squeeze(), cmap='gray')
plt.show()
```

2. Anomaly Detection

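An autoencoder trained only on normal data usually reconstructs unusual inputs poorly, so samples whose reconstruction error exceeds a threshold can be flagged as anomalies: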

```python
def detect_anomalies(model, data_loader, threshold):
    model.eval()
    anomalies = []

    with torch.no_grad():
        for data in data_loader:
            img, label = data

            # Get reconstruction
            output = model(img)

            # Calculate reconstruction error
            error = torch.mean((output - img) ** 2, dim=(1, 2, 3))

            # Detect anomalies based on threshold
            for i, err in enumerate(error):
                if err > threshold:
                    anomalies.append((img[i], label[i], err.item()))

    return anomalies
```
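`detect_anomalies` leaves the choice of `threshold` open. One common heuristic, sketched here as an assumption rather than a fixed rule, is to compute reconstruction errors on data known to be normal and use a high percentile of them as the threshold. `per_sample_errors` and `choose_threshold` are hypothetical helpers, and the random tensors stand in for real images and reconstructions.

```python
import torch

def per_sample_errors(output, img):
    """Mean squared reconstruction error for each image in a batch."""
    return torch.mean((output - img) ** 2, dim=(1, 2, 3))

def choose_threshold(errors, percentile=99.0):
    """Return the value below which `percentile` % of the errors fall."""
    return torch.quantile(errors, percentile / 100.0).item()

# Stand-in for errors on "normal" data: near-perfect reconstructions
torch.manual_seed(0)
img = torch.rand(100, 1, 28, 28)
output = img + 0.05 * torch.randn_like(img)
errors = per_sample_errors(output, img)
threshold = choose_threshold(errors, percentile=99.0)

# A badly reconstructed input (all zeros instead of all ones) exceeds the threshold
bad_error = per_sample_errors(torch.zeros(1, 1, 28, 28), torch.ones(1, 1, 28, 28))
print(threshold, bad_error.item() > threshold)  # second value printed: True
```

The percentile controls the trade-off. A higher value flags fewer normal samples by mistake but may miss subtle anomalies.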

3. Dimensionality Reduction and Feature Learning

The encoder's output can serve as a compact, nonlinear alternative to PCA for dimensionality reduction. Here the latent codes are extracted and visualized with t-SNE:

```python
def extract_features(model, data_loader):
    model.eval()
    features = []
    labels = []

    with torch.no_grad():
        for data in data_loader:
            img, label = data

            # Extract features from the encoder
            x = img.view(-1, 28 * 28)
            encoded_features = model.encoder(x)

            features.append(encoded_features)
            labels.append(label)

    # Concatenate all batches
    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0)

    return features, labels

# Visualize features with t-SNE
from sklearn.manifold import TSNE

# Extract features from test set
# (`model` must be a fully connected autoencoder whose `encoder` accepts
# flattened 28x28 inputs, such as Autoencoder or DenoisingAutoencoder)
features, labels = extract_features(model, test_loader)

# Apply t-SNE
tsne = TSNE(n_components=2, random_state=42)
features_2d = tsne.fit_transform(features.numpy())

# Plot the results
plt.figure(figsize=(10, 8))
for digit in range(10):
    idx = labels.numpy() == digit
    plt.scatter(features_2d[idx, 0], features_2d[idx, 1], label=f"Digit {digit}")
plt.legend()
plt.title("t-SNE visualization of autoencoder features")
plt.show()
```
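Since the text positions autoencoders as an alternative to PCA, a linear PCA baseline makes a useful comparison. Below is a minimal sketch using `torch.pca_lowrank`, assuming images are flattened to 784-dimensional rows. `pca_features` is a hypothetical helper, and the random matrix stands in for the flattened test images.

```python
import torch

def pca_features(x, k=16):
    """Project the rows of x onto their top-k principal components."""
    mean = x.mean(dim=0, keepdim=True)
    # torch.pca_lowrank returns (U, S, V); the columns of V are principal directions
    _, _, V = torch.pca_lowrank(x, q=k, center=True)
    return (x - mean) @ V[:, :k]

torch.manual_seed(0)
flat_images = torch.rand(500, 28 * 28)   # stand-in for test images viewed as (N, 784)
pca_feats = pca_features(flat_images, k=16)
print(pca_feats.shape)                   # torch.Size([500, 16])
```

For a like-for-like comparison, build the input from the real data (for example, `torch.cat([img.view(-1, 28 * 28) for img, _ in test_loader])`), keep `k` equal to the autoencoder's latent size, and pass the result to the same t-SNE code in place of `features`.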

Practical Tips for Training Autoencoders

  1. Model Architecture: Choose the complexity of your autoencoder based on your data. More complex data might require deeper architectures.

  2. Latent Space Dimensionality: The size of the bottleneck layer trades compression against reconstruction quality. A smaller latent space compresses more but usually loses more detail, so experiment to find a size that suits your data.

  3. Loss Function: For image data, consider perceptual losses or the structural similarity index (SSIM) in addition to, or instead of, plain MSE.

  4. Regularization: Add regularization, such as a weight penalty or weight decay, to reduce overfitting, especially on smaller datasets.

```python
# Example of L1 regularization in the loss function
def regularized_loss(recon_x, x, model, lambda_reg=0.001):
    mse_loss = nn.functional.mse_loss(recon_x, x)

    # L1 penalty summed over all parameters (weights and biases alike)
    l1_reg = 0
    for param in model.parameters():
        l1_reg += torch.sum(torch.abs(param))

    return mse_loss + lambda_reg * l1_reg
```
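A related variant, usually called a sparse autoencoder, applies the L1 penalty to the latent activations instead of the parameters, which pushes most latent units toward zero for any given input. The sketch assumes a model with separate `encoder` and `decoder` modules, as the tutorial's fully connected models have. `TinyAE`, `sparse_loss`, and `lambda_sparse` are hypothetical names. If you want plain L2 weight decay instead, `optim.Adam` accepts a built-in `weight_decay` argument.

```python
import torch
import torch.nn as nn

class TinyAE(nn.Module):
    """Stand-in autoencoder with separate encoder/decoder, like the tutorial's models."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 16), nn.ReLU())
        self.decoder = nn.Sequential(nn.Linear(16, 28 * 28), nn.Sigmoid())

def sparse_loss(model, x, lambda_sparse=1e-3):
    """MSE reconstruction loss plus an L1 penalty on the latent activations."""
    x = x.view(-1, 28 * 28)
    z = model.encoder(x)                 # latent codes, shape (B, 16)
    recon = model.decoder(z)
    mse = nn.functional.mse_loss(recon, x)
    l1_latent = z.abs().mean()           # average absolute latent activation
    return mse + lambda_sparse * l1_latent, mse, l1_latent

tiny_ae = TinyAE()
loss, mse, l1_latent = sparse_loss(tiny_ae, torch.rand(8, 1, 28, 28))
loss.backward()                          # gradients flow through both terms
```

Penalizing activations rather than weights constrains what each input's code looks like, while leaving the network's capacity otherwise unchanged.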

Summary

Autoencoders are versatile neural network architectures that learn to encode data into a lower-dimensional latent space and then reconstruct it. We've covered:

  • Basic autoencoder architecture with fully connected layers
  • Convolutional autoencoders for image data
  • Variational autoencoders that learn a probabilistic latent space
  • Denoising autoencoders for noise removal
  • Real-world applications including image compression, denoising, anomaly detection, and dimensionality reduction

Autoencoders are foundational models in unsupervised learning and generative modeling, forming the basis for more advanced techniques in the deep learning field.


Exercises

  1. Basic Exercise: Implement a simple autoencoder for the MNIST dataset and visualize the reconstructions.

  2. Intermediate Exercise: Create a denoising autoencoder and test it with different noise levels.

  3. Advanced Exercise: Implement a variational autoencoder and generate new, synthetic images by sampling from the latent space.

  4. Challenge: Build a conditional autoencoder that can reconstruct specific digits based on a conditioning label.

Happy coding with PyTorch autoencoders!


