
TensorFlow MCMC

Introduction

Markov Chain Monte Carlo (MCMC) is a powerful technique used in Bayesian statistics to sample from complex probability distributions. When we have a probability distribution that's difficult to sample from directly, MCMC allows us to draw samples by constructing a Markov chain with the desired distribution as its equilibrium distribution.

TensorFlow Probability (TFP) provides a comprehensive suite of MCMC algorithms that are efficient, parallelizable, and compatible with automatic differentiation. In this tutorial, we'll explore how to use TFP's MCMC tools to solve Bayesian inference problems.

Why MCMC?

Before diving into the code, let's understand why MCMC is necessary:

  1. Complex Posterior Distributions: In Bayesian inference, the posterior often has no closed form, and its normalizing constant is intractable to compute
  2. High Dimensions: Many real-world problems involve high-dimensional parameter spaces
  3. Integration: MCMC helps compute expectations (integrals) over these complex distributions
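Point 3 is the core idea: once we have draws θ₁, ..., θ_N from a distribution p, any expectation E_p[f(θ)] can be approximated by the average (1/N) Σ f(θ_i). A minimal NumPy sketch, using independent draws from a standard normal as a stand-in for MCMC output (a real chain gives correlated draws, which makes the estimate noisier but still consistent):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for posterior samples: 100,000 independent draws from N(0, 1).
draws = rng.normal(loc=0.0, scale=1.0, size=100_000)

# Monte Carlo estimates of expectations under N(0, 1).
# Exact values: E[x] = 0, E[x^2] = 1, P(x > 1.96) is about 0.025.
mean_estimate = np.mean(draws)
second_moment_estimate = np.mean(draws**2)
tail_prob_estimate = np.mean(draws > 1.96)  # probability = expectation of an indicator

print(f"E[x]        ~ {mean_estimate:.3f}")
print(f"E[x^2]      ~ {second_moment_estimate:.3f}")
print(f"P(x > 1.96) ~ {tail_prob_estimate:.4f}")
```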

Basic MCMC Concepts

An MCMC algorithm typically consists of:

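  1. Target Distribution: The distribution we want to sample from, usually known only up to a normalizing constant and supplied as a log-density function
  2. Proposal Distribution: A distribution used to generate a candidate state from the current one
  3. Transition Kernel: The complete rule for moving from one state to the next, i.e. the proposal plus an acceptance test (typically Metropolis-Hastings) that keeps the target as the chain's stationary distribution
  4. Chain: The sequence of states visited, including repeated states whenever a proposal is rejected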
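To make the four pieces concrete, here is a from-scratch random-walk Metropolis sampler in NumPy. It is an illustrative sketch, not TFP code (TFP's counterpart is `tfp.mcmc.RandomWalkMetropolis`), and the function name `random_walk_metropolis` is our own. The target is an unnormalized standard normal, the proposal is a Gaussian step around the current state, the kernel applies the Metropolis test, and the chain records the current state at every iteration, including repeats after rejections.

```python
import numpy as np

def log_target(x):
    """Unnormalized log density of N(0, 1); the normalizing constant is not needed."""
    return -0.5 * x**2

def random_walk_metropolis(log_prob, x0, num_steps, proposal_scale, seed=0):
    rng = np.random.default_rng(seed)
    chain = np.empty(num_steps)
    x, log_p = x0, log_prob(x0)
    num_accepted = 0
    for t in range(num_steps):
        # Proposal: symmetric Gaussian step around the current state.
        x_prop = x + proposal_scale * rng.normal()
        log_p_prop = log_prob(x_prop)
        # Acceptance test: accept with probability min(1, p(x_prop) / p(x)).
        if np.log(rng.uniform()) < log_p_prop - log_p:
            x, log_p = x_prop, log_p_prop
            num_accepted += 1
        # Chain: record the current state whether or not we moved.
        chain[t] = x
    return chain, num_accepted / num_steps

chain, accept_rate = random_walk_metropolis(log_target, x0=3.0,
                                            num_steps=20_000, proposal_scale=1.0)
kept = chain[2_000:]  # discard burn-in
print(f"acceptance rate: {accept_rate:.2f}")
print(f"mean ~ {kept.mean():.3f}, std ~ {kept.std():.3f}")
```

Because the Gaussian proposal is symmetric, the proposal densities cancel in the Metropolis-Hastings ratio, which is why only the target appears in the acceptance test.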

Getting Started with TensorFlow Probability MCMC

First, let's import the necessary libraries:

```python
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt

tfd = tfp.distributions
tfb = tfp.bijectors
```

A Simple Example: Estimating the Mean of a Normal Distribution

Let's start with a simple example: estimating the mean of a normal distribution given some observed data.

```python
# Generate some synthetic data (float32, to match TFP's default dtype)
true_mean = 2.0
true_std = 1.5
data = np.random.normal(true_mean, true_std, size=1000).astype(np.float32)

# Define the log probability function
def log_prob_fn(mean, std=true_std):
    """Log likelihood of the data under a normal distribution with unknown mean.

    No prior term is included, which amounts to a flat (improper) prior on the mean.
    """
    rv = tfd.Normal(loc=mean, scale=std)
    return tf.reduce_sum(rv.log_prob(data))

# Wrap it to match the signature MCMC kernels expect: one argument per state part
def target_log_prob_fn(mean):
    return log_prob_fn(mean)
```

Now, let's use the Hamiltonian Monte Carlo (HMC) algorithm, which is one of the most popular MCMC methods:

```python
# Define the initial state
initial_state = tf.constant(0.0)  # Start with a guess of 0 for the mean

# Chain length and burn-in
num_results = 1000
num_burnin_steps = 500  # Number of initial steps to discard (burn-in)

# Initial step size for the leapfrog integrator (adapted during burn-in)
step_size = 0.03

# Define the HMC kernel
kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    num_leapfrog_steps=10,
    step_size=step_size)

# Adapt the step size during the first 80% of burn-in
adaptive_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=int(0.8 * num_burnin_steps))

# Run the chain; tf.function compiles the sampling loop into a graph,
# which is usually much faster than running it eagerly
@tf.function
def run_hmc_chain():
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=adaptive_kernel,
        trace_fn=lambda _, pkr: [pkr.inner_results.is_accepted,
                                 pkr.inner_results.accepted_results.step_size])

samples, [is_accepted, step_size_trace] = run_hmc_chain()
```

Let's analyze the results:

```python
# Print out statistics
print(f"True mean: {true_mean:.4f}")
print(f"Sample mean: {np.mean(samples):.4f}")
print(f"Standard deviation of samples: {np.std(samples):.4f}")
print(f"Acceptance rate: {np.mean(is_accepted):.4f}")

# Plot the histogram of samples
plt.figure(figsize=(10, 6))
plt.hist(samples, bins=30, density=True, alpha=0.7)
plt.axvline(true_mean, color='r', linestyle='--', label=f'True Mean: {true_mean}')
plt.axvline(np.mean(samples), color='g', linestyle='--',
            label=f'MCMC Mean: {np.mean(samples):.4f}')
plt.title('Posterior Distribution of Mean')
plt.xlabel('Mean Value')
plt.ylabel('Density')
plt.legend()
plt.show()
```

Example output (exact values vary from run to run, since both the data and the chain are random):

```
True mean: 2.0000
Sample mean: 1.9946
Standard deviation of samples: 0.0476
Acceptance rate: 0.8050
```

The histogram would show the posterior distribution centered around the true mean of 2.0.
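This example can be checked exactly. Because log_prob_fn contains only the likelihood (a flat prior on the mean) and the standard deviation is fixed at 1.5, the posterior for the mean is Normal with mean equal to the sample mean of the data and standard deviation 1.5 / sqrt(1000) ≈ 0.0474. A NumPy sketch of that calculation; it draws its own data with a fixed seed, so to check the run above, substitute its data array:

```python
import numpy as np

rng = np.random.default_rng(42)
true_mean, true_std, n = 2.0, 1.5, 1000
data = rng.normal(true_mean, true_std, size=n)

# Flat prior + Normal likelihood with known std => Normal posterior for the mean,
# centered on the sample mean with standard deviation sigma / sqrt(n).
posterior_mean = data.mean()
posterior_std = true_std / np.sqrt(n)

print(f"exact posterior: mean={posterior_mean:.4f}, std={posterior_std:.4f}")
```

The MCMC standard deviation of 0.0476 in the example output is in line with this value of about 0.0474.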

MCMC Algorithms in TensorFlow Probability

TFP offers several MCMC algorithms, each with its own strengths:

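  1. Random Walk Metropolis (`tfp.mcmc.RandomWalkMetropolis`): Simple and gradient-free, but can be inefficient, especially in high dimensions
  2. Hamiltonian Monte Carlo (`tfp.mcmc.HamiltonianMonteCarlo`): Uses gradients of the log density to make long, efficient moves
  3. No U-Turn Sampler (`tfp.mcmc.NoUTurnSampler`): A variant of HMC that chooses the number of leapfrog steps automatically; the step size is still tuned separately, e.g. with `tfp.mcmc.DualAveragingStepSizeAdaptation`
  4. Metropolis-Adjusted Langevin Algorithm (`tfp.mcmc.MetropolisAdjustedLangevinAlgorithm`): Proposes a gradient-informed step plus Gaussian noise, then applies a Metropolis correction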
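The way HMC (and NUTS, which is built on it) uses gradient information is through the leapfrog integrator: it attaches a random momentum to the state, simulates Hamiltonian dynamics for several steps, and accepts or rejects the endpoint based on the change in total energy. Below is a minimal NumPy sketch for a two-dimensional standard normal target, where the gradient of the log density is simply -x; the function names are our own, and TFP's `HamiltonianMonteCarlo` obtains gradients by automatic differentiation instead.

```python
import numpy as np

def log_prob(x):
    return -0.5 * np.sum(x**2)         # standard normal, unnormalized

def grad_log_prob(x):
    return -x                           # gradient of log_prob

def leapfrog(x, p, step_size, num_steps):
    """Simulate Hamiltonian dynamics with the leapfrog integrator."""
    p = p + 0.5 * step_size * grad_log_prob(x)        # half step for momentum
    for i in range(num_steps):
        x = x + step_size * p                          # full step for position
        if i < num_steps - 1:
            p = p + step_size * grad_log_prob(x)       # full step for momentum
    p = p + 0.5 * step_size * grad_log_prob(x)        # final half step
    return x, p

def hmc_step(x, step_size, num_leapfrog_steps, rng):
    p0 = rng.normal(size=x.shape)                      # resample momentum
    x_new, p_new = leapfrog(x, p0, step_size, num_leapfrog_steps)
    # Total energy H = -log p(x) + |p|^2 / 2; accept with prob min(1, exp(-dH)).
    h_old = -log_prob(x) + 0.5 * np.sum(p0**2)
    h_new = -log_prob(x_new) + 0.5 * np.sum(p_new**2)
    if np.log(rng.uniform()) < h_old - h_new:
        return x_new, True
    return x, False

rng = np.random.default_rng(1)
x = np.array([3.0, -3.0])
draws, accepted = [], 0
for _ in range(5_000):
    x, ok = hmc_step(x, step_size=0.2, num_leapfrog_steps=8, rng=rng)
    draws.append(x)
    accepted += ok
draws = np.array(draws[500:])                          # discard burn-in
print(f"acceptance: {accepted / 5000:.2f}")
print("mean ~", draws.mean(axis=0), "std ~", draws.std(axis=0))
```

The trajectory length (0.2 × 8 = 1.6) is roughly a quarter period of the dynamics for this target, so successive draws are nearly independent; NUTS exists largely to pick such lengths automatically.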

Let's implement the same example using NUTS:

```python
# Define the NUTS kernel
nuts_kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.03)

# Add dual-averaging step size adaptation
adaptive_nuts = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=nuts_kernel,
    num_adaptation_steps=int(0.8 * num_burnin_steps),
    target_accept_prob=0.8)

# Run the chain. No trace_fn is given, so the full kernel results are traced;
# nuts_trace.inner_results.has_divergence is used later to count divergences.
@tf.function
def run_nuts_chain():
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=adaptive_nuts)

nuts_samples, nuts_trace = run_nuts_chain()

print(f"NUTS sample mean: {np.mean(nuts_samples):.4f}")
print(f"NUTS sample standard deviation: {np.std(nuts_samples):.4f}")
```

Diagnosing MCMC Convergence

It's important to check if our MCMC chain has converged to the target distribution. Here are some diagnostic tools:

Trace Plots

```python
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(samples)
plt.title('HMC Trace')
plt.xlabel('Iteration')
plt.ylabel('Parameter Value')

plt.subplot(1, 2, 2)
plt.plot(nuts_samples)
plt.title('NUTS Trace')
plt.xlabel('Iteration')
plt.ylabel('Parameter Value')
plt.tight_layout()
plt.show()
```

Autocorrelation Plots

```python
def autocorrelation(x, max_lag=100):
    """Normalized autocorrelation of a 1D chain for lags 0..max_lag-1."""
    x = np.asarray(x, dtype=float)
    x = x - np.mean(x)
    result = np.correlate(x, x, mode='full')
    result = result[len(result) // 2:]  # keep non-negative lags
    result = result[:max_lag]
    return result / result[0]  # normalize so lag 0 equals 1

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(autocorrelation(samples))
plt.title('HMC Autocorrelation')
plt.xlabel('Lag')
plt.ylabel('Autocorrelation')

plt.subplot(1, 2, 2)
plt.plot(autocorrelation(nuts_samples))
plt.title('NUTS Autocorrelation')
plt.xlabel('Lag')
plt.ylabel('Autocorrelation')
plt.tight_layout()
plt.show()
```
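Autocorrelation is often summarized as an effective sample size (ESS): N correlated draws carry roughly as much information about the mean as ESS = N / (1 + 2 Σ_k ρ_k) independent ones, where ρ_k is the lag-k autocorrelation. TFP provides `tfp.mcmc.effective_sample_size` for this. The NumPy sketch below is a simplified version of our own (FFT-based autocorrelation, sum truncated at the first negative lag; estimators used in practice are more careful), tested on an AR(1) series with coefficient φ, whose true ESS is N (1 - φ) / (1 + φ).

```python
import numpy as np

def autocorrelation_fft(x):
    """Normalized autocorrelation for all lags, computed with an FFT."""
    x = np.asarray(x, dtype=float) - np.mean(x)
    n = len(x)
    f = np.fft.rfft(x, n=2 * n)               # zero-pad to avoid circular wrap-around
    acov = np.fft.irfft(f * np.conj(f))[:n]   # autocovariance (times n), lags 0..n-1
    return acov / acov[0]

def effective_sample_size(x):
    rho = autocorrelation_fft(x)
    # Sum autocorrelations up to (not including) the first negative lag.
    negative = np.nonzero(rho[1:] < 0)[0]
    cutoff = negative[0] + 1 if negative.size else len(rho)
    tau = 1.0 + 2.0 * np.sum(rho[1:cutoff])   # integrated autocorrelation time
    return len(x) / tau

# AR(1) process x_t = phi * x_{t-1} + noise, started in its stationary distribution.
rng = np.random.default_rng(0)
phi, n = 0.9, 100_000
noise = rng.normal(size=n)
ar1 = np.empty(n)
ar1[0] = rng.normal() / np.sqrt(1 - phi**2)
for t in range(1, n):
    ar1[t] = phi * ar1[t - 1] + noise[t]

print(f"theoretical ESS: {n * (1 - phi) / (1 + phi):.0f}")
print(f"estimated ESS:   {effective_sample_size(ar1):.0f}")
```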

Real-World Example: Bayesian Linear Regression

Let's implement a more practical example with Bayesian linear regression:

```python
# Generate synthetic data (cast to float32 so it matches the float32 weights)
n_samples = 100
true_weights = np.array([0.5, 2.0, -1.5])  # Intercept and two features
X = np.random.normal(0, 1, size=(n_samples, 2))
X_with_intercept = np.column_stack([np.ones(n_samples), X]).astype(np.float32)
noise = np.random.normal(0, 0.5, size=n_samples)
y = (X_with_intercept @ true_weights + noise).astype(np.float32)

# Prior scale for the weights; the noise scale is treated as known (0.5)
weight_prior_scale = 10.0

# Define the model's log probability function
def blr_log_prob_fn(weights, noise_scale=0.5):
    """Log joint density (prior + likelihood) for Bayesian linear regression."""
    # Independent Normal(0, 10) prior on each weight
    prior = tfd.Normal(loc=0., scale=weight_prior_scale)
    prior_log_prob = tf.reduce_sum(prior.log_prob(weights))

    # Gaussian likelihood with known noise scale
    y_pred = tf.linalg.matvec(X_with_intercept, weights)
    likelihood = tfd.Normal(loc=y_pred, scale=noise_scale)
    log_likelihood = tf.reduce_sum(likelihood.log_prob(y))

    return prior_log_prob + log_likelihood

# Initial state
initial_weights = tf.zeros(3)

# Set up NUTS sampler with dual-averaging step size adaptation
nuts = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=blr_log_prob_fn,
    step_size=0.1)

adaptive_nuts = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=nuts,
    num_adaptation_steps=800,
    target_accept_prob=0.8)

# Run the MCMC chain
num_results = 2000
num_burnin_steps = 1000

@tf.function
def run_blr_chain():
    # With trace_fn=None, sample_chain returns only the states (no trace)
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_weights,
        kernel=adaptive_nuts,
        trace_fn=None)

weight_samples = run_blr_chain()  # shape (num_results, 3)

# Analyze results
weight_posterior_means = tf.reduce_mean(weight_samples, axis=0).numpy()
weight_posterior_stds = tf.math.reduce_std(weight_samples, axis=0).numpy()

print("True weights:", true_weights)
print("Posterior means:", weight_posterior_means)
print("Posterior standard deviations:", weight_posterior_stds)

# Visualize posterior distributions
plt.figure(figsize=(15, 5))
param_names = ['Intercept', 'Weight 1', 'Weight 2']

for i in range(3):
    plt.subplot(1, 3, i + 1)
    plt.hist(weight_samples[:, i].numpy(), bins=30, density=True, alpha=0.7)
    plt.axvline(true_weights[i], color='r', linestyle='--', label=f'True: {true_weights[i]}')
    plt.axvline(weight_posterior_means[i], color='g', linestyle='--',
                label=f'Mean: {weight_posterior_means[i]:.4f}')
    plt.title(f'Posterior for {param_names[i]}')
    plt.legend()

plt.tight_layout()
plt.show()
```
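With a Gaussian prior on the weights and a known noise scale, this particular model is conjugate, so the exact posterior is also available in closed form: covariance Σ = (XᵀX / σ² + I / τ²)⁻¹ and mean Σ Xᵀy / σ², with σ = 0.5 and τ = weight_prior_scale = 10. Comparing these with weight_posterior_means and weight_posterior_stds is a quick check that the sampler works. A NumPy sketch that regenerates data with its own seed; to check the run above, substitute its X_with_intercept and y:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 100
true_weights = np.array([0.5, 2.0, -1.5])
X = rng.normal(0, 1, size=(n_samples, 2))
X_with_intercept = np.column_stack([np.ones(n_samples), X])
y = X_with_intercept @ true_weights + rng.normal(0, 0.5, size=n_samples)

weight_prior_scale = 10.0   # tau: prior std of each weight
noise_scale = 0.5           # sigma: known observation noise

# Gaussian prior x Gaussian likelihood => Gaussian posterior over the weights.
precision = (X_with_intercept.T @ X_with_intercept / noise_scale**2
             + np.eye(3) / weight_prior_scale**2)
posterior_cov = np.linalg.inv(precision)
posterior_mean = posterior_cov @ (X_with_intercept.T @ y) / noise_scale**2
posterior_std = np.sqrt(np.diag(posterior_cov))

print("exact posterior means:", np.round(posterior_mean, 4))
print("exact posterior stds: ", np.round(posterior_std, 4))
```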

Making Predictions with the Posterior

One of the advantages of Bayesian inference is that we can make predictions that incorporate parameter uncertainty:

```python
# Create test data
X_test = np.random.normal(0, 1, size=(20, 2))
X_test_with_intercept = np.column_stack([np.ones(20), X_test])

# Predict the mean response for every posterior sample: shape (20, num_results)
all_predictions = np.dot(X_test_with_intercept, weight_samples.numpy().T)

# Posterior mean and 95% credible interval of the mean response at each test point
pred_means = np.mean(all_predictions, axis=1)
pred_lower = np.percentile(all_predictions, 2.5, axis=1)
pred_upper = np.percentile(all_predictions, 97.5, axis=1)

# Calculate the true (noise-free) mean response
true_y = np.dot(X_test_with_intercept, true_weights)

# Plot predictions with uncertainty
plt.figure(figsize=(10, 6))
plt.errorbar(range(20), pred_means,
             yerr=[pred_means - pred_lower, pred_upper - pred_means],
             fmt='o', ecolor='lightgray', capsize=5,
             label='Posterior mean with 95% credible interval')
plt.plot(range(20), true_y, 'rx', label='True mean response')
plt.title('Bayesian Linear Regression Predictions with Uncertainty')
plt.xlabel('Test Sample')
plt.ylabel('Predicted Value')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
```
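The intervals above reflect uncertainty in the weights only, i.e. in the mean response X·w. A new observation also includes noise, so a posterior predictive interval adds a noise draw (σ = 0.5, the noise scale assumed by the model) to each sample's prediction. A NumPy sketch; weight_draws is a hypothetical stand-in for weight_samples.numpy(), generated here so the block runs on its own:

```python
import numpy as np

rng = np.random.default_rng(0)
noise_scale = 0.5

# Hypothetical stand-in for weight_samples.numpy(): 2000 draws of 3 weights.
weight_draws = rng.normal([0.5, 2.0, -1.5], 0.05, size=(2000, 3))
X_test = np.column_stack([np.ones(20), rng.normal(0, 1, size=(20, 2))])

# Uncertainty in the mean response only (as in the plot above): shape (20, 2000).
mean_draws = X_test @ weight_draws.T
# Posterior predictive: add observation noise to every draw.
pred_draws = mean_draws + rng.normal(0, noise_scale, size=mean_draws.shape)

mean_lo, mean_hi = np.percentile(mean_draws, [2.5, 97.5], axis=1)
pred_lo, pred_hi = np.percentile(pred_draws, [2.5, 97.5], axis=1)
print("avg 95% width, mean response:  ", np.mean(mean_hi - mean_lo).round(3))
print("avg 95% width, new observation:", np.mean(pred_hi - pred_lo).round(3))
```

The predictive intervals are much wider here because, with 100 observations, the noise (σ = 0.5) dominates the remaining uncertainty in the weights.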

Advanced Topic: Dealing with Constrained Parameters

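Sometimes parameters are constrained (e.g., a standard deviation must be positive). TFP handles this with bijectors: TransformedTransitionKernel runs the sampler in an unconstrained space and maps each state back to the constrained space through a bijector, accounting for the change of variables automatically: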

```python
# Generate some data from a normal distribution (float32 to match TFP defaults)
data = np.random.normal(2.0, 1.5, size=1000).astype(np.float32)

# Log probability written on the constrained scale: std must be positive
def constrained_log_prob_fn(mean, std):
    """Joint log density of (mean, std) with priors and a Normal likelihood."""
    # Priors for mean and std
    mean_prior = tfd.Normal(loc=0., scale=10.0)
    std_prior = tfd.LogNormal(loc=0., scale=1.0)
    prior_log_prob = mean_prior.log_prob(mean) + std_prior.log_prob(std)

    # Likelihood
    rv = tfd.Normal(loc=mean, scale=std)
    log_likelihood = tf.reduce_sum(rv.log_prob(data))

    return prior_log_prob + log_likelihood

# Initial state on the constrained scale: [mean, std]
initial_state = [tf.constant(0.0), tf.constant(1.0)]

# One bijector per state part, mapping unconstrained reals onto each support
unconstraining_bijectors = [
    tfb.Identity(),  # mean is unconstrained
    tfb.Exp()        # maps the real line to (0, inf), so std stays positive
]

# HMC runs in the unconstrained space; the kernel adds the log-Jacobian term
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=constrained_log_prob_fn,
        step_size=0.05,
        num_leapfrog_steps=10),
    bijector=unconstraining_bijectors)

adaptive_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=800)

# Run the chain (with trace_fn=None, only the list of states is returned)
@tf.function
def run_constrained_chain():
    return tfp.mcmc.sample_chain(
        num_results=2000,
        num_burnin_steps=1000,
        current_state=initial_state,
        kernel=adaptive_kernel,
        trace_fn=None)

samples = run_constrained_chain()

# States come back on the constrained scale, so no manual transform is needed
mean_samples = samples[0].numpy()
std_samples = samples[1].numpy()

print(f"True mean: 2.0, Estimated mean: {np.mean(mean_samples):.4f}")
print(f"True std: 1.5, Estimated std: {np.mean(std_samples):.4f}")

# Plot posterior distributions
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(mean_samples, bins=30, density=True)
plt.axvline(2.0, color='r', linestyle='--', label='True Mean')
plt.title('Posterior for Mean')
plt.legend()

plt.subplot(1, 2, 2)
plt.hist(std_samples, bins=30, density=True)
plt.axvline(1.5, color='r', linestyle='--', label='True Std Dev')
plt.title('Posterior for Standard Deviation')
plt.legend()
plt.tight_layout()
plt.show()
```
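The bijector matters because sampling on a transformed scale changes the density. If std = exp(u), the density of u is p_u(u) = p_std(exp(u)) · exp(u), so log p_u(u) = log p_std(exp(u)) + u. TransformedTransitionKernel adds this log-Jacobian term automatically; if you sample log_std by hand instead, you have to add it yourself. A quick numerical check using only NumPy: when std ~ LogNormal(0, 1), u = log(std) must be exactly N(0, 1), and that only holds once the Jacobian term is included.

```python
import numpy as np

def lognormal_logpdf(s, mu=0.0, sigma=1.0):
    return (-np.log(s * sigma * np.sqrt(2 * np.pi))
            - (np.log(s) - mu)**2 / (2 * sigma**2))

def normal_logpdf(u, mu=0.0, sigma=1.0):
    return -np.log(sigma * np.sqrt(2 * np.pi)) - (u - mu)**2 / (2 * sigma**2)

u = np.linspace(-3, 3, 13)
with_jacobian = lognormal_logpdf(np.exp(u)) + u   # + log|d exp(u) / du| = + u
without_jacobian = lognormal_logpdf(np.exp(u))

print(np.allclose(with_jacobian, normal_logpdf(u)))     # True
print(np.allclose(without_jacobian, normal_logpdf(u)))  # False
```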

Common Challenges and Solutions

1. Poor Mixing

If your chain isn't exploring the parameter space efficiently:

  • Increase the number of burn-in and sampling steps
  • Adjust the step size
  • Use a more efficient sampler like NUTS
  • Reparameterize your model
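Step size is usually the first knob to try. The NumPy sketch below uses a hypothetical helper, rwm_chain (random-walk Metropolis on a standard normal target), to show the trade-off: tiny steps are almost always accepted but barely move, huge steps are almost always rejected, and both extremes produce highly autocorrelated chains. Lag-1 autocorrelation serves as a crude mixing measure.

```python
import numpy as np

def rwm_chain(scale, num_steps=20_000, seed=0):
    """Random-walk Metropolis on N(0, 1); returns the chain and acceptance rate."""
    rng = np.random.default_rng(seed)
    x, chain, accepted = 0.0, np.empty(num_steps), 0
    for t in range(num_steps):
        prop = x + scale * rng.normal()
        # log p(prop) - log p(x) for the standard normal target
        if np.log(rng.uniform()) < 0.5 * (x**2 - prop**2):
            x, accepted = prop, accepted + 1
        chain[t] = x
    return chain, accepted / num_steps

def lag1_autocorr(chain):
    c = chain - chain.mean()
    return np.dot(c[:-1], c[1:]) / np.dot(c, c)

results = {}
for scale in [0.05, 2.5, 50.0]:
    chain, acc = rwm_chain(scale)
    rho1 = lag1_autocorr(chain)
    results[scale] = (acc, rho1)
    print(f"scale={scale:>5}: acceptance={acc:.2f}, lag-1 autocorr={rho1:.3f}")
```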

2. Divergences

Divergences can occur when the step size is too large or the posterior has challenging geometry:

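```python
# Monitor divergences with NUTS. nuts_trace holds the full kernel results
# traced by the NUTS example above, one entry per post-burn-in step.
divergent_traces = nuts_trace.inner_results.has_divergence

print(f"Number of divergences: {np.sum(divergent_traces)}")
print(f"Divergence rate: {np.mean(divergent_traces):.4f}")
```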
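In TFP's NUTS, a step is flagged as divergent when the energy error along the leapfrog trajectory exceeds a threshold (the max_energy_diff argument of tfp.mcmc.NoUTurnSampler). The NumPy sketch below shows the mechanism on a one-dimensional standard normal target, for which the leapfrog integrator is stable only when the step size is below 2: below that limit the energy error stays bounded, just above it the error grows without bound. The helper name leapfrog_energy_error is our own.

```python
import numpy as np

def leapfrog_energy_error(step_size, num_steps, x0=1.0, p0=1.0):
    """Absolute energy error after leapfrog integration of H(x, p) = x^2/2 + p^2/2."""
    x, p = x0, p0
    h0 = 0.5 * x**2 + 0.5 * p**2
    for _ in range(num_steps):
        p -= 0.5 * step_size * x    # half step for momentum (gradient of x^2/2 is x)
        x += step_size * p          # full step for position
        p -= 0.5 * step_size * x    # half step for momentum
    return abs(0.5 * x**2 + 0.5 * p**2 - h0)

for eps in [0.1, 1.0, 1.9, 2.1]:
    print(f"step_size={eps}: energy error after 50 steps = "
          f"{leapfrog_energy_error(eps, 50):.3g}")
```

Real posteriors have regions of high curvature where the local stability limit is much smaller, which is why divergences tend to cluster there and why lowering the step size or reparameterizing helps.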

3. Initialization

Proper initialization can speed up convergence:

```python
# Initialize at a reasonable guess: the least-squares fit, which is the
# maximum likelihood estimate for this Gaussian linear model
initial_guess, *_ = np.linalg.lstsq(X_with_intercept, y, rcond=None)
initial_weights = tf.constant(initial_guess, dtype=tf.float32)  # pass as current_state
```
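A complementary check is to start several chains from dispersed initial states and compare them with the potential scale reduction factor (R-hat); values close to 1 suggest the chains agree, while values well above 1 mean at least one chain has not converged. TFP provides tfp.mcmc.potential_scale_reduction. The NumPy sketch below implements the basic (non-split) Gelman-Rubin formula for chains of shape (num_chains, num_draws); the demo uses simulated chains, not output from the examples above.

```python
import numpy as np

def r_hat(chains):
    """Basic Gelman-Rubin potential scale reduction for chains of shape (m, n)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    between = n * chain_means.var(ddof=1)          # B: between-chain variance
    within = chains.var(axis=1, ddof=1).mean()     # W: mean within-chain variance
    var_plus = (n - 1) / n * within + between / n  # pooled variance estimate
    return np.sqrt(var_plus / within)

rng = np.random.default_rng(0)
mixed = rng.normal(0.0, 1.0, size=(4, 1000))             # all chains agree
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])   # one chain is offset
print(f"R-hat, well-mixed chains: {r_hat(mixed):.3f}")
print(f"R-hat, one stuck chain:   {r_hat(stuck):.3f}")
```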

Summary

In this tutorial, we've explored MCMC methods in TensorFlow Probability for Bayesian inference. We covered:

  1. The basic concepts of MCMC
  2. Different MCMC algorithms available in TFP
  3. How to implement and diagnose MCMC chains
  4. Real-world applications with Bayesian linear regression
  5. Advanced topics like handling constrained parameters

MCMC methods are powerful tools for Bayesian inference, allowing us to sample from complex posterior distributions that don't have analytical solutions. TensorFlow Probability makes these methods accessible, efficient, and compatible with modern deep learning workflows.


Exercises

  1. Modify the Bayesian linear regression example to include more features and analyze how the posterior uncertainty changes.
  2. Implement a logistic regression model using MCMC for a classification problem.
  3. Compare the performance of different MCMC algorithms (HMC, NUTS, Random Walk Metropolis) on a simple problem.
  4. Implement a hierarchical Bayesian model using MCMC in TensorFlow Probability.
  5. Use MCMC to estimate the parameters of a time series model.

