TensorFlow Variational Autoencoders

Introduction

Variational Autoencoders (VAEs) represent a fascinating intersection of deep learning and probabilistic modeling. Unlike traditional autoencoders that simply compress and reconstruct data, VAEs learn the underlying probability distribution of the input data, making them powerful generative models.

In this tutorial, we'll explore VAEs using TensorFlow, starting from the basic theory and gradually implementing a complete model that can generate new data samples. By the end of this guide, you'll understand:

  • The theory behind VAEs and how they differ from traditional autoencoders
  • How to implement VAEs in TensorFlow
  • Ways to train and evaluate VAE models
  • Practical applications of VAEs in real-world scenarios

What is a Variational Autoencoder?

A Variational Autoencoder is a type of generative model that learns to encode input data into a latent space (a compressed representation) and then decode it back to reconstruct the original input. What makes VAEs special is that they don't just learn a deterministic mapping but rather learn parameters of probability distributions.

Traditional Autoencoders vs. VAEs

Traditional autoencoders map inputs to a fixed point in a latent space, while VAEs map inputs to a probability distribution in that space. This probabilistic approach allows VAEs to:

  1. Generate new, realistic samples by sampling from the learned distribution
  2. Create a smooth, continuous latent space where similar inputs are close together
  3. Handle uncertainty in a principled way

The Mathematics Behind VAEs

VAEs consist of two main components:

  1. Encoder: Maps input data to parameters of a probability distribution (usually Gaussian)
  2. Decoder: Maps samples from this distribution back to the input space

The training objective of a VAE combines:

  • Reconstruction loss (how well the decoded output matches the original input)
  • KL divergence (ensures the learned distribution is close to a standard normal distribution)
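
Concretely, for an encoder that outputs a diagonal Gaussian q(z|x) = N(μ, σ²I) with a standard normal prior, these two terms make up the negative evidence lower bound (ELBO). The KL term has a closed form for Gaussians, which is exactly what we'll implement in code later:

latex
\mathcal{L}(x) = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\left[\log p(x \mid z)\right]}_{\text{reconstruction loss}} + \underbrace{D_{\mathrm{KL}}\!\left(q(z \mid x) \,\|\, \mathcal{N}(0, I)\right)}_{\text{KL divergence}}

D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)

where d is the dimensionality of the latent space.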

Implementing a VAE in TensorFlow

Let's implement a simple VAE for the MNIST dataset:

python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

Step 1: Prepare the Dataset

python
# Load and preprocess MNIST dataset
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

# Normalize and reshape the data
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Reshape to (batch_size, height, width, channels)
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# Create TF Dataset
batch_size = 128
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

test_dataset = tf.data.Dataset.from_tensor_slices(x_test).batch(batch_size)

Step 2: Define the VAE Model

python
class Sampling(keras.layers.Layer):
    """Custom layer implementing the reparameterization trick."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class VAE(keras.Model):
    def __init__(self, latent_dim=2, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim

        # Define the encoder: image -> flat feature vector
        self.encoder = keras.Sequential([
            keras.layers.InputLayer(input_shape=(28, 28, 1)),
            keras.layers.Conv2D(32, 3, strides=2, padding='same', activation='relu'),
            keras.layers.Conv2D(64, 3, strides=2, padding='same', activation='relu'),
            keras.layers.Flatten(),
            keras.layers.Dense(16, activation='relu'),
        ])

        # Separate heads for the mean and log-variance of q(z|x)
        self.dense_mean = keras.layers.Dense(latent_dim)
        self.dense_log_var = keras.layers.Dense(latent_dim)

        # The sampling layer
        self.sampling = Sampling()

        # Define the decoder: latent vector -> reconstructed image
        self.decoder = keras.Sequential([
            keras.layers.InputLayer(input_shape=(latent_dim,)),
            keras.layers.Dense(7 * 7 * 64, activation='relu'),
            keras.layers.Reshape((7, 7, 64)),
            keras.layers.Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu'),
            keras.layers.Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu'),
            keras.layers.Conv2DTranspose(1, 3, padding='same', activation='sigmoid'),
        ])

    def encode(self, x):
        x = self.encoder(x)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        return z_mean, z_log_var

    def decode(self, z):
        # The decoder's final sigmoid already maps outputs to [0, 1]
        return self.decoder(z)

    def call(self, inputs):
        z_mean, z_log_var = self.encode(inputs)
        z = self.sampling((z_mean, z_log_var))
        return self.decode(z)

    def train_step(self, data):
        with tf.GradientTape() as tape:
            # Encode and get latent distribution parameters
            z_mean, z_log_var = self.encode(data)

            # Sample from the latent distribution
            z = self.sampling((z_mean, z_log_var))

            # Decode
            reconstructed = self.decode(z)

            # Reconstruction loss: binary_crossentropy averages over the
            # channel axis, leaving shape (batch, 28, 28), so sum over the
            # two spatial axes before averaging over the batch
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstructed),
                    axis=(1, 2)
                )
            )

            # KL divergence between q(z|x) and the standard normal prior
            kl_loss = -0.5 * tf.reduce_mean(
                tf.reduce_sum(
                    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
                    axis=1
                )
            )

            # Total loss (the negative ELBO)
            total_loss = reconstruction_loss + kl_loss

        # Compute gradients outside the tape context
        grads = tape.gradient(total_loss, self.trainable_weights)

        # Update weights
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
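
As a quick sanity check before training (a minimal sketch using the classes defined above), you can instantiate the model and confirm the tensor shapes:

python
# Run one dummy image through the model and check shapes
vae_check = VAE(latent_dim=2)
dummy = tf.zeros((1, 28, 28, 1))

z_mean, z_log_var = vae_check.encode(dummy)
print(z_mean.shape, z_log_var.shape)  # (1, 2) (1, 2)
print(vae_check(dummy).shape)         # (1, 28, 28, 1)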

Step 3: Train the VAE Model

python
# Create and compile the VAE model
latent_dim = 2 # 2D latent space for easy visualization
vae = VAE(latent_dim)
vae.compile(optimizer=keras.optimizers.Adam(1e-3))

# Train the model
epochs = 10
history = vae.fit(train_dataset, epochs=epochs)  # the dataset is already batched

# Let's plot the training loss
plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'])
plt.title('VAE Training Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.grid(True)
plt.show()
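
Beyond the loss curve, a quick qualitative check is to compare a few test digits with their reconstructions (a minimal sketch, reusing x_test from Step 1):

python
# Reconstruct the first few test digits and plot originals vs. reconstructions
n = 8
originals = x_test[:n]
reconstructions = vae(originals)

fig, axes = plt.subplots(2, n, figsize=(n * 2, 4))
for i in range(n):
    axes[0, i].imshow(originals[i].reshape(28, 28), cmap='gray')
    axes[0, i].axis('off')
    axes[1, i].imshow(reconstructions[i].numpy().reshape(28, 28), cmap='gray')
    axes[1, i].axis('off')
plt.tight_layout()
plt.show()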

Step 4: Visualize and Use the VAE

python
# Visualize the latent space
def plot_latent_space(vae, n=30, figsize=15):
    # Display an n x n grid of digits decoded from the latent space
    figure = np.zeros((28 * n, 28 * n))

    # Linearly spaced coordinates in the latent space
    grid_x = np.linspace(-3, 3, n)
    grid_y = np.linspace(-3, 3, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]], dtype='float32')
            x_decoded = vae.decode(z_sample)
            digit = tf.reshape(x_decoded[0], (28, 28))
            figure[i * 28: (i + 1) * 28, j * 28: (j + 1) * 28] = digit.numpy()

    plt.figure(figsize=(figsize, figsize))
    start_range = 0.5 * 28  # Center of the first digit
    end_range = (n - 0.5) * 28  # Center of the last digit
    pixel_range = np.arange(start_range, end_range, 28)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()

# Visualize the latent space
plot_latent_space(vae)

Expected output: The above code will display a grid of generated MNIST digits. As you move through the latent space, you'll see the digits gradually transform from one to another, demonstrating the smooth transition property of the VAE's latent space.

Step 5: Generate New Samples

python
# Generate new digits by sampling from latent space
def generate_samples(vae, num_samples=10):
    # Sample random points from the standard normal prior
    z = tf.random.normal(shape=(num_samples, vae.latent_dim))
    # Decode the samples into images
    samples = vae.decode(z)

    # Plot the generated digits
    fig, axes = plt.subplots(1, num_samples, figsize=(num_samples * 2, 2))
    for i, sample in enumerate(samples):
        axes[i].imshow(sample.numpy().reshape(28, 28), cmap='gray')
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()

# Generate 10 random digits
generate_samples(vae, 10)

Expected output: This will display a row of 10 randomly generated digits. Since the model has learned the distribution of the MNIST dataset, these should look like realistic handwritten digits.

Understanding the Code

The Reparameterization Trick

One of the key innovations in VAEs is the reparameterization trick, implemented in our Sampling layer. This trick allows us to backpropagate through the random sampling process:

python
class Sampling(keras.layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

Instead of sampling z directly from q(z|x) (which would break gradient flow through the stochastic node), we sample noise ε from a standard normal distribution and compute z = μ + σ·ε deterministically. Gradients can then flow back through μ and log σ² to the encoder's weights.

The Loss Function

The VAE loss has two components:

  1. Reconstruction Loss: How well the decoder reconstructs the input from the latent representation
  2. KL Divergence: Forces the learned latent distribution to be close to a standard normal distribution
python
# Compute reconstruction loss: binary_crossentropy averages over the
# channel axis, so we sum over the two remaining spatial axes
reconstruction_loss = tf.reduce_mean(
    tf.reduce_sum(
        keras.losses.binary_crossentropy(data, reconstructed),
        axis=(1, 2)
    )
)

# Compute KL divergence loss (closed form for diagonal Gaussians)
kl_loss = -0.5 * tf.reduce_mean(
    tf.reduce_sum(
        1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
        axis=1
    )
)

# Total loss
total_loss = reconstruction_loss + kl_loss

The KL divergence term acts as a regularizer, preventing the model from simply memorizing the training data.

Real-World Applications of VAEs

VAEs have numerous applications across different domains:

1. Image Generation and Editing

python
# Simple example of latent space manipulation for image editing
def interpolate_digits(vae, digit1, digit2, steps=10):
    # Get the latent means for the two digits
    z_mean1, _ = vae.encode(digit1.reshape(1, 28, 28, 1))
    z_mean2, _ = vae.encode(digit2.reshape(1, 28, 28, 1))

    # Create interpolation points in the latent space
    alphas = np.linspace(0, 1, steps)
    z_interp = np.zeros((steps, vae.latent_dim), dtype='float32')

    for i, alpha in enumerate(alphas):
        z_interp[i] = alpha * z_mean2 + (1 - alpha) * z_mean1

    # Decode the interpolated points
    decoded = vae.decode(z_interp)

    # Plot the results
    fig, axes = plt.subplots(1, steps, figsize=(steps * 2, 2))
    for i in range(steps):
        axes[i].imshow(decoded[i].numpy().reshape(28, 28), cmap='gray')
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()

# Get two test digits
digit1 = x_test[0]  # First test digit
digit2 = x_test[1]  # Second test digit

# Interpolate between them
interpolate_digits(vae, digit1, digit2)

2. Anomaly Detection

VAEs can identify anomalous data by measuring reconstruction error:

python
def detect_anomalies(vae, data, threshold=0.05):
    # Compute per-sample reconstruction error (mean squared error over pixels)
    reconstructed = vae(data)
    mse = tf.reduce_mean(tf.square(data - reconstructed), axis=[1, 2, 3]).numpy()

    # Samples with high reconstruction error are flagged as anomalous.
    # With pixel values in [0, 1], per-pixel MSE is also in [0, 1], so the
    # threshold must be small; calibrate it on held-out normal data.
    anomalies = data[mse > threshold]
    normal = data[mse <= threshold]

    print(f"Found {len(anomalies)} anomalies in {len(data)} samples")
    return anomalies, normal, mse
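
Without a labeled anomaly set at hand, one rough way to exercise the detector is to corrupt a few clean digits with Gaussian noise (a hedged sketch; the noise level and threshold are assumptions you would tune):

python
# Corrupt 100 test digits with additive Gaussian noise
clean = x_test[:100]
noisy = np.clip(clean + np.random.normal(0, 0.5, clean.shape), 0.0, 1.0).astype('float32')
mixed = np.concatenate([clean, noisy])

# The noisy half should show noticeably higher reconstruction error
anomalies, normal, mse = detect_anomalies(vae, mixed, threshold=0.05)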

3. Drug Discovery

VAEs are used in computational drug discovery to generate novel molecular structures:

python
# This is pseudocode to illustrate the concept
"""
# Encode known drug molecules into latent space
z_mean, z_log_var = molecular_vae.encode(known_drugs)

# Sample new points in latent space
z_new = sample_from_latent_space(n_samples=100)

# Decode to get novel molecular structures
new_molecules = molecular_vae.decode(z_new)

# Filter based on drug-likeness properties
promising_drugs = filter_by_properties(new_molecules)
"""

Advanced VAE Variants

Conditional VAE (CVAE)

We can extend our VAE to generate samples conditioned on specific attributes:

python
class CVAE(keras.Model):
    def __init__(self, latent_dim=2, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # The encoder and decoder networks would be modified to accept
        # both the input data and the conditional information.
        # For example, the encoder would concatenate the image with a
        # one-hot encoded class label, and similarly for the decoder;
        # see the sketch below.
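
Here is a minimal sketch of how the conditioning could plug in (the helper names are illustrative, not part of any library): the label is one-hot encoded and appended as extra input channels for the encoder, and concatenated onto the latent vector for the decoder.

python
def condition_encoder_input(x, labels, num_classes=10):
    # x: (batch, 28, 28, 1) images; labels: (batch,) integer class ids
    onehot = tf.one_hot(labels, num_classes)             # (batch, 10)
    maps = tf.reshape(onehot, (-1, 1, 1, num_classes))   # (batch, 1, 1, 10)
    maps = tf.tile(maps, (1, 28, 28, 1))                 # (batch, 28, 28, 10)
    return tf.concat([x, maps], axis=-1)                 # (batch, 28, 28, 11)

def condition_decoder_input(z, labels, num_classes=10):
    # z: (batch, latent_dim) latent samples
    onehot = tf.one_hot(labels, num_classes)
    return tf.concat([z, onehot], axis=-1)               # (batch, latent_dim + 10)

With this scheme, the encoder's InputLayer would expect 11 channels instead of 1, and the decoder's input would be latent_dim + num_classes wide.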

β-VAE

β-VAE introduces a hyperparameter β that balances reconstruction quality against learning of disentangled latent representations:

python
# In the train_step method:
total_loss = reconstruction_loss + beta * kl_loss
# where beta > 1 puts more emphasis on learning disentangled representations
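
A minimal sketch of how this could look, assuming the VAE class defined earlier: subclass it and override train_step so only the KL weighting changes.

python
class BetaVAE(VAE):
    def __init__(self, latent_dim=2, beta=4.0, **kwargs):
        super().__init__(latent_dim=latent_dim, **kwargs)
        self.beta = beta

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var = self.encode(data)
            z = self.sampling((z_mean, z_log_var))
            reconstructed = self.decode(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstructed),
                    axis=(1, 2)
                )
            )
            kl_loss = -0.5 * tf.reduce_mean(
                tf.reduce_sum(
                    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
                    axis=1
                )
            )
            # The only change from the plain VAE: scale the KL term by beta
            total_loss = reconstruction_loss + self.beta * kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }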

Summary

In this tutorial, we've explored Variational Autoencoders in TensorFlow, covering:

  1. The theoretical foundation of VAEs and how they differ from traditional autoencoders
  2. Implementation of a VAE for the MNIST dataset in TensorFlow
  3. Visualization and interpretation of the latent space
  4. Sampling and generation of new data points
  5. Real-world applications of VAEs
  6. Advanced VAE variants for specific tasks

VAEs represent a powerful bridge between deep learning and probabilistic modeling, enabling both high-quality generation and meaningful latent representations. Unlike GANs, they provide both a generative model and an inference network, making them versatile tools in machine learning.

Additional Resources and Exercises

Resources

  1. TensorFlow VAE Tutorial
  2. Original VAE paper: Kingma and Welling, "Auto-Encoding Variational Bayes"
  3. Tutorial on Variational Autoencoders

Exercises

  1. Modified Architecture: Experiment with different encoder and decoder architectures. How does the quality of generated samples change?

  2. Different Datasets: Apply the VAE to a different dataset, such as Fashion-MNIST or CIFAR-10. What challenges do you encounter?

  3. Disentanglement: Implement a β-VAE with β > 1 and visualize how different latent dimensions control different features of the generated images.

  4. Conditional Generation: Extend the VAE to a conditional VAE (CVAE) that can generate digits of a specific class.

  5. Anomaly Detection: Use your trained VAE for anomaly detection by measuring reconstruction error on modified MNIST digits (e.g., digits with random noise or structural changes).

  6. Latent Space Arithmetic: Try performing arithmetic in the latent space (e.g., adding and subtracting encoded representations) to achieve semantic operations on the generated samples.



If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)