TensorFlow MCMC
Introduction
Markov Chain Monte Carlo (MCMC) is a powerful technique used in Bayesian statistics to sample from complex probability distributions. When we have a probability distribution that's difficult to sample from directly, MCMC allows us to draw samples by constructing a Markov chain with the desired distribution as its equilibrium distribution.
TensorFlow Probability (TFP) provides a comprehensive suite of MCMC algorithms that are efficient, parallelizable, and compatible with automatic differentiation. In this tutorial, we'll explore how to use TFP's MCMC tools to solve Bayesian inference problems.
Why MCMC?
Before diving into the code, let's understand why MCMC is necessary:
- Complex Posterior Distributions: In Bayesian inference, we often end up with posterior distributions that don't have analytical solutions
- High Dimensions: Many real-world problems involve high-dimensional parameter spaces
- Integration: MCMC helps compute expectations (integrals) over these complex distributions
Basic MCMC Concepts
An MCMC algorithm typically consists of the following pieces (a toy sketch follows the list):
- Target Distribution: The distribution we want to sample from
- Proposal Distribution: A distribution used to generate candidate samples
- Transition Kernel: The rule that moves the chain from one state to the next, typically by proposing a candidate and then accepting or rejecting it
- Chain: The resulting sequence of samples, whose long-run distribution approaches the target
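To make these pieces concrete, here is a toy random-walk Metropolis sampler written in plain NumPy. The target, proposal scale, and step count below are arbitrary choices for illustration; TFP's kernels do the same job far more efficiently.
# A minimal random-walk Metropolis sampler, for intuition only
import numpy as np

def toy_metropolis(log_prob, n_steps=5000, proposal_scale=0.5):
    x = 0.0                                                   # current state of the chain
    chain = []
    for _ in range(n_steps):
        proposal = x + np.random.normal(0.0, proposal_scale)  # candidate from the proposal distribution
        log_accept_ratio = log_prob(proposal) - log_prob(x)   # Metropolis acceptance ratio
        if np.log(np.random.uniform()) < log_accept_ratio:    # accept or reject
            x = proposal
        chain.append(x)                                       # the chain records the current state either way
    return np.array(chain)

# Target: an unnormalized standard normal, so the draws should have mean ~0 and std ~1.
toy_samples = toy_metropolis(lambda x: -0.5 * x ** 2)
print(toy_samples.mean(), toy_samples.std())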
Getting Started with TensorFlow Probability MCMC
First, let's import the necessary libraries:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt
tfd = tfp.distributions
tfb = tfp.bijectors
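Optionally, fixing the random seeds makes the runs below reproducible (the seed value is arbitrary), and printing the library versions helps when comparing behaviour against the TFP documentation:
np.random.seed(42)      # reproducible synthetic data
tf.random.set_seed(42)  # reproducible TensorFlow-level randomness
print("TensorFlow:", tf.__version__, "| TFP:", tfp.__version__)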
A Simple Example: Estimating the Mean of a Normal Distribution
Let's start with a simple example: estimating the mean of a normal distribution given some observed data.
# Generate some synthetic data (cast to float32 to match TFP's default dtype)
true_mean = 2.0
true_std = 1.5
data = np.random.normal(true_mean, true_std, size=1000).astype(np.float32)
# Define the log probability function
def log_prob_fn(mean, std=true_std):
    """Log probability function for a normal distribution with unknown mean."""
    rv = tfd.Normal(loc=mean, scale=std)
    return tf.reduce_sum(rv.log_prob(data))
# Wrap it to match the expected signature for MCMC
def target_log_prob_fn(mean):
    return log_prob_fn(mean)
Now, let's use the Hamiltonian Monte Carlo (HMC) algorithm, which is one of the most popular MCMC methods:
# Define the initial state
initial_state = tf.constant(0.0) # Start with a guess of 0 for the mean
# Set up the HMC transition kernel
num_results = 1000
num_burnin_steps = 500 # Number of steps to discard (burn-in)
# Step size for the leapfrog integrator
step_size = 0.03
# Define the kernel
kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    num_leapfrog_steps=10,
    step_size=step_size)
# Add adaptation for step size
adaptive_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=int(0.8 * num_burnin_steps))
# Run the chain
samples, [is_accepted, step_size_trace] = tfp.mcmc.sample_chain(
    num_results=num_results,
    num_burnin_steps=num_burnin_steps,
    current_state=initial_state,
    kernel=adaptive_kernel,
    trace_fn=lambda _, pkr: [pkr.inner_results.is_accepted,
                             pkr.inner_results.accepted_results.step_size])
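A practical note: running sample_chain eagerly executes one op at a time and can be slow for larger models, so it is common to wrap the sampling call in tf.function. A sketch that produces the same outputs as the call above; run_hmc is just an illustrative wrapper name, and the jit_compile flag (XLA) is optional:
@tf.function(autograph=False, jit_compile=True)  # jit_compile is optional; drop it if XLA causes trouble
def run_hmc():
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=adaptive_kernel,
        trace_fn=lambda _, pkr: [pkr.inner_results.is_accepted,
                                 pkr.inner_results.accepted_results.step_size])

samples, [is_accepted, step_size_trace] = run_hmc()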
Let's analyze the results:
# Print out statistics
print(f"True mean: {true_mean:.4f}")
print(f"Sample mean: {np.mean(samples):.4f}")
print(f"Standard deviation of samples: {np.std(samples):.4f}")
print(f"Acceptance rate: {np.mean(is_accepted):.4f}")
# Plot the histogram of samples
plt.figure(figsize=(10, 6))
plt.hist(samples, bins=30, density=True, alpha=0.7)
plt.axvline(true_mean, color='r', linestyle='--', label=f'True Mean: {true_mean}')
plt.axvline(np.mean(samples), color='g', linestyle='--',
            label=f'MCMC Mean: {np.mean(samples):.4f}')
plt.title('Posterior Distribution of Mean')
plt.xlabel('Mean Value')
plt.ylabel('Density')
plt.legend()
plt.show()
Output:
True mean: 2.0000
Sample mean: 1.9946
Standard deviation of samples: 0.0476
Acceptance rate: 0.8050
The histogram would show the posterior distribution centered around the true mean of 2.0.
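As a quick sanity check: with the noise standard deviation treated as known, a flat prior gives a posterior for the mean whose standard deviation is sigma / sqrt(n), which is essentially what the sampler recovered above.
print(f"Analytic posterior std of the mean: {true_std / np.sqrt(len(data)):.4f}")  # ~0.0474, close to the sampled 0.0476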
MCMC Algorithms in TensorFlow Probability
TFP offers several MCMC algorithms, each with its own strengths:
- Random Walk Metropolis (RWM): Simple, but can mix slowly, especially in high dimensions (see the sketch after this list)
- Hamiltonian Monte Carlo (HMC): Uses gradient information for efficient exploration
- No U-Turn Sampler (NUTS): An adaptive variant of HMC that automatically tunes parameters
- Metropolis-Adjusted Langevin Algorithm (MALA): Combines random walks with gradient information
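For comparison, the first sampler in this list can be dropped into exactly the same workflow. A sketch using tfp.mcmc.RandomWalkMetropolis on the earlier target_log_prob_fn; the proposal scale of 0.05 is an arbitrary choice:
rwm_kernel = tfp.mcmc.RandomWalkMetropolis(
    target_log_prob_fn=target_log_prob_fn,
    new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=0.05))

rwm_samples, rwm_is_accepted = tfp.mcmc.sample_chain(
    num_results=num_results,
    num_burnin_steps=num_burnin_steps,
    current_state=initial_state,
    kernel=rwm_kernel,
    trace_fn=lambda _, pkr: pkr.is_accepted)

print(f"RWM sample mean: {np.mean(rwm_samples):.4f}")
print(f"RWM acceptance rate: {np.mean(rwm_is_accepted):.4f}")
Because RWM ignores gradient information, expect noticeably higher autocorrelation than HMC or NUTS for the same number of samples.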
Let's implement the same example using NUTS:
# Define the NUTS kernel
nuts_kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.03)
# Add adaptation
adaptive_nuts = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=nuts_kernel,
    num_adaptation_steps=int(0.8 * num_burnin_steps),
    target_accept_prob=0.8)
# Run the chain
nuts_samples, nuts_trace = tfp.mcmc.sample_chain(
    num_results=num_results,
    num_burnin_steps=num_burnin_steps,
    current_state=initial_state,
    kernel=adaptive_nuts)
print(f"NUTS sample mean: {np.mean(nuts_samples):.4f}")
print(f"NUTS sample standard deviation: {np.std(nuts_samples):.4f}")
Diagnosing MCMC Convergence
It's important to check if our MCMC chain has converged to the target distribution. Here are some diagnostic tools:
Trace Plots
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(samples)
plt.title('HMC Trace')
plt.xlabel('Iteration')
plt.ylabel('Parameter Value')
plt.subplot(1, 2, 2)
plt.plot(nuts_samples)
plt.title('NUTS Trace')
plt.xlabel('Iteration')
plt.ylabel('Parameter Value')
plt.tight_layout()
plt.show()
Autocorrelation Plots
def autocorrelation(x, max_lag=100):
    """Calculate the autocorrelation of a 1D array (or eager tensor) of samples."""
    x = np.asarray(x, dtype=np.float64)  # works for NumPy arrays and eager tensors alike
    x = x - np.mean(x)
    result = np.correlate(x, x, mode='full')
    result = result[len(result) // 2:]
    result = result[:max_lag]
    return result / result[0]
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(autocorrelation(samples))
plt.title('HMC Autocorrelation')
plt.xlabel('Lag')
plt.ylabel('Autocorrelation')
plt.subplot(1, 2, 2)
plt.plot(autocorrelation(nuts_samples))
plt.title('NUTS Autocorrelation')
plt.xlabel('Lag')
plt.ylabel('Autocorrelation')
plt.tight_layout()
plt.show()
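High autocorrelation translates directly into a small effective sample size (ESS): the number of effectively independent draws the chain contains. TFP can compute this for us:
ess_hmc = tfp.mcmc.effective_sample_size(samples)
ess_nuts = tfp.mcmc.effective_sample_size(nuts_samples)
print(f"HMC effective sample size:  {ess_hmc.numpy():.1f} of {num_results} draws")
print(f"NUTS effective sample size: {ess_nuts.numpy():.1f} of {num_results} draws")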
Real-World Example: Bayesian Linear Regression
Let's implement a more practical example with Bayesian linear regression:
# Generate synthetic data (float32 throughout, to match TFP's default dtype)
n_samples = 100
true_weights = np.array([0.5, 2.0, -1.5], dtype=np.float32)  # Intercept and two features
X = np.random.normal(0, 1, size=(n_samples, 2)).astype(np.float32)
X_with_intercept = np.column_stack([np.ones(n_samples, dtype=np.float32), X])
noise = np.random.normal(0, 0.5, size=n_samples).astype(np.float32)
y = X_with_intercept @ true_weights + noise
# Prior scale for the weights (the observation noise scale is treated as known and fixed at 0.5)
weight_prior_scale = 10.0
# Define the model's log probability function
def blr_log_prob_fn(weights, noise_scale=0.5):
    """Log probability for Bayesian linear regression (noise scale held fixed at 0.5)."""
    # Prior for the weights
    prior = tfd.Normal(loc=0., scale=weight_prior_scale)
    prior_log_prob = tf.reduce_sum(prior.log_prob(weights))
    # Likelihood
    y_pred = tf.linalg.matvec(X_with_intercept, weights)
    likelihood = tfd.Normal(loc=y_pred, scale=noise_scale)
    log_likelihood = tf.reduce_sum(likelihood.log_prob(y))
    return prior_log_prob + log_likelihood
# Initial state
initial_weights = tf.zeros(3)
# Set up NUTS sampler
nuts = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=blr_log_prob_fn,
    step_size=0.1)
adaptive_nuts = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=nuts,
    num_adaptation_steps=800,
    target_accept_prob=0.8)
# Run the MCMC chain
num_results = 2000
num_burnin_steps = 1000
# With trace_fn=None, sample_chain returns only the samples (no trace to unpack)
weight_samples = tfp.mcmc.sample_chain(
    num_results=num_results,
    num_burnin_steps=num_burnin_steps,
    current_state=initial_weights,
    kernel=adaptive_nuts,
    trace_fn=None)
# Analyze results
weight_posterior_means = tf.reduce_mean(weight_samples, axis=0).numpy()
weight_posterior_stds = tf.math.reduce_std(weight_samples, axis=0).numpy()
print("True weights:", true_weights)
print("Posterior means:", weight_posterior_means)
print("Posterior standard deviations:", weight_posterior_stds)
# Visualize posterior distributions
plt.figure(figsize=(15, 5))
param_names = ['Intercept', 'Weight 1', 'Weight 2']
for i in range(3):
    plt.subplot(1, 3, i + 1)
    plt.hist(weight_samples[:, i], bins=30, density=True, alpha=0.7)
    plt.axvline(true_weights[i], color='r', linestyle='--', label=f'True: {true_weights[i]}')
    plt.axvline(weight_posterior_means[i], color='g', linestyle='--',
                label=f'Mean: {weight_posterior_means[i]:.4f}')
    plt.title(f'Posterior for {param_names[i]}')
    plt.legend()
plt.tight_layout()
plt.show()
Making Predictions with the Posterior
One of the advantages of Bayesian inference is that we can make predictions that incorporate parameter uncertainty:
# Create test data
X_test = np.random.normal(0, 1, size=(20, 2))
X_test_with_intercept = np.column_stack([np.ones(20), X_test])
# Make predictions using all posterior samples
all_predictions = np.dot(X_test_with_intercept, weight_samples.numpy().T)
# Calculate mean and 95% credible intervals
pred_means = np.mean(all_predictions, axis=1)
pred_lower = np.percentile(all_predictions, 2.5, axis=1)
pred_upper = np.percentile(all_predictions, 97.5, axis=1)
# Calculate true values
true_y = np.dot(X_test_with_intercept, true_weights)
# Plot predictions with uncertainty
plt.figure(figsize=(10, 6))
plt.errorbar(range(20), pred_means, yerr=[pred_means - pred_lower, pred_upper - pred_means],
             fmt='o', ecolor='lightgray', capsize=5, label='Predicted with 95% CI')
plt.plot(range(20), true_y, 'rx', label='True Values')
plt.title('Bayesian Linear Regression Predictions with Uncertainty')
plt.xlabel('Test Sample')
plt.ylabel('Predicted Value')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
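Note that these intervals reflect parameter uncertainty only. For a full posterior predictive interval over new observations we would also add the observation noise to each draw; a sketch, assuming (as in the model above) a known noise scale of 0.5:
# Posterior predictive draws: parameter uncertainty plus observation noise
posterior_predictive = all_predictions + np.random.normal(0.0, 0.5, size=all_predictions.shape)
pp_lower = np.percentile(posterior_predictive, 2.5, axis=1)
pp_upper = np.percentile(posterior_predictive, 97.5, axis=1)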
Advanced Topic: Dealing with Constrained Parameters
Sometimes our parameters have constraints (e.g., standard deviation must be positive). TFP allows us to handle this with bijectors:
# Generate some data from a normal distribution (again cast to float32)
data = np.random.normal(2.0, 1.5, size=1000).astype(np.float32)
# Define the log probability for the mean and standard deviation.
# The constraint (std > 0) is handled by the transition kernel's bijector below,
# so the log probability can be written directly in terms of std.
def constrained_log_prob_fn(mean, std):
    """Log probability with a positive constraint on the standard deviation."""
    # Priors for mean and std
    mean_prior = tfd.Normal(loc=0., scale=10.0)
    std_prior = tfd.LogNormal(loc=0., scale=1.0)
    prior_log_prob = mean_prior.log_prob(mean) + std_prior.log_prob(std)
    # Likelihood
    rv = tfd.Normal(loc=mean, scale=std)
    log_likelihood = tf.reduce_sum(rv.log_prob(data))
    return prior_log_prob + log_likelihood
# Initial state (in the constrained space, so std must start positive)
initial_state = [tf.constant(0.0), tf.constant(1.0)]  # [mean, std]
# Define the transition kernel with bijectors that map the unconstrained space
# HMC works in to the constrained space of the parameters
unconstraining_bijectors = [
    tfb.Identity(),  # The mean is unconstrained
    tfb.Exp()        # exp maps the real line to std > 0
]
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=constrained_log_prob_fn,
        step_size=0.05,
        num_leapfrog_steps=10),
    bijector=unconstraining_bijectors)
adaptive_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=800)
# Run the chain (trace_fn=None, so only the samples are returned)
samples = tfp.mcmc.sample_chain(
    num_results=2000,
    num_burnin_steps=1000,
    current_state=initial_state,
    kernel=adaptive_kernel,
    trace_fn=None)
# Extract the samples; they are already in the constrained space
mean_samples = samples[0].numpy()
std_samples = samples[1].numpy()
print(f"True mean: 2.0, Estimated mean: {np.mean(mean_samples):.4f}")
print(f"True std: 1.5, Estimated std: {np.mean(std_samples):.4f}")
# Plot posterior distributions
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(mean_samples, bins=30, density=True)
plt.axvline(2.0, color='r', linestyle='--', label='True Mean')
plt.title('Posterior for Mean')
plt.legend()
plt.subplot(1, 2, 2)
plt.hist(std_samples, bins=30, density=True)
plt.axvline(1.5, color='r', linestyle='--', label='True Std Dev')
plt.title('Posterior for Standard Deviation')
plt.legend()
plt.tight_layout()
plt.show()
Common Challenges and Solutions
1. Poor Mixing
If your chain isn't exploring the parameter space efficiently, try the following (a multi-chain diagnostic sketch follows the list):
- Increase the number of steps
- Adjust the step size
- Use a more efficient sampler like NUTS
- Reparameterize your model
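A cheap way to expose poor mixing is to run several chains from different starting points and compare them. In TFP, a leading batch dimension on the state gives independent chains, provided the target log-probability function returns one value per chain. A sketch for the normal-mean example (four chains, arbitrary starting points; batched_target_log_prob_fn is an illustrative name):
# Vectorized target: one log probability per chain
def batched_target_log_prob_fn(means):
    rv = tfd.Normal(loc=means[..., tf.newaxis], scale=true_std)
    return tf.reduce_sum(rv.log_prob(data), axis=-1)

chain_states, _ = tfp.mcmc.sample_chain(
    num_results=1000,
    num_burnin_steps=500,
    current_state=tf.linspace(-2.0, 4.0, 4),  # one starting point per chain
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=batched_target_log_prob_fn,
        num_leapfrog_steps=10,
        step_size=0.03),
    trace_fn=lambda _, pkr: pkr.is_accepted)

# Gelman-Rubin statistic: values close to 1 suggest the chains agree
rhat = tfp.mcmc.potential_scale_reduction(chain_states, independent_chain_ndims=1)
print(f"R-hat: {rhat.numpy():.4f}")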
2. Divergences
Divergences can occur when the step size is too large or the posterior has challenging geometry:
# Monitor divergences with NUTS
divergent_traces = nuts_trace.inner_results.has_divergence
print(f"Number of divergences: {np.sum(divergent_traces)}")
print(f"Divergence rate: {np.mean(divergent_traces):.4f}")
3. Initialization
Proper initialization can speed up convergence:
# Initialize with a reasonable guess, e.g., the ordinary least squares fit,
# which is the maximum likelihood estimate for this regression model
initial_guess, *_ = np.linalg.lstsq(X_with_intercept, y, rcond=None)
initial_guess = tf.constant(initial_guess, dtype=tf.float32)
Summary
In this tutorial, we've explored MCMC methods in TensorFlow Probability for Bayesian inference. We covered:
- The basic concepts of MCMC
- Different MCMC algorithms available in TFP
- How to implement and diagnose MCMC chains
- Real-world applications with Bayesian linear regression
- Advanced topics like handling constrained parameters
MCMC methods are powerful tools for Bayesian inference, allowing us to sample from complex posterior distributions that don't have analytical solutions. TensorFlow Probability makes these methods accessible, efficient, and compatible with modern deep learning workflows.
Additional Resources
- TensorFlow Probability MCMC Documentation
- A Visual Exploration of MCMC Methods
- Probabilistic Programming & Bayesian Methods for Hackers
Exercises
- Modify the Bayesian linear regression example to include more features and analyze how the posterior uncertainty changes.
- Implement a logistic regression model using MCMC for a classification problem.
- Compare the performance of different MCMC algorithms (HMC, NUTS, Random Walk Metropolis) on a simple problem.
- Implement a hierarchical Bayesian model using MCMC in TensorFlow Probability.
- Use MCMC to estimate the parameters of a time series model.
If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)