TensorFlow Causal Inference
Introduction
Causal inference is a critical field in data science that aims to understand cause-and-effect relationships beyond mere correlations. While traditional machine learning focuses on prediction, causal inference allows us to answer questions like "What would happen if we intervene?" or "Why did this outcome occur?"
TensorFlow Probability (TFP) provides probability distributions and sampling utilities that make convenient building blocks for causal inference methods within the TensorFlow ecosystem. In this tutorial, we'll explore how to use TensorFlow for causal inference tasks, understand the fundamentals of causal reasoning, and implement practical examples that demonstrate these concepts.
What is Causal Inference?
Causal inference refers to the process of determining the effect of actions or treatments on outcomes. Unlike predictive modeling that focuses on correlations, causal inference aims to understand:
- Causation vs. Correlation: Distinguishing between "X is associated with Y" and "X causes Y"
- Counterfactual Reasoning: What would have happened if an alternative action had been taken
- Interventions: Understanding the effects of changing variables in a system
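The difference between association and causation can be illustrated with a small simulation. The sketch below is not part of the tutorial's TensorFlow code: it uses plain NumPy, and the hidden common cause `c` and all coefficients are assumptions chosen for illustration. X has no effect on Y here, yet the two are strongly correlated until X is set by intervention.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Observational world: a hidden common cause C drives both X and Y.
# X has NO causal effect on Y in this toy system.
c = rng.normal(size=n)
x_obs = c + 0.5 * rng.normal(size=n)
y_obs = c + 0.5 * rng.normal(size=n)
corr_observed = np.corrcoef(x_obs, y_obs)[0, 1]

# Interventional world: we set X ourselves, independently of C.
# Y's mechanism is unchanged and still ignores X.
x_do = rng.normal(size=n)
y_do = c + 0.5 * rng.normal(size=n)
corr_intervened = np.corrcoef(x_do, y_do)[0, 1]

print(f"Correlation when X is merely observed:     {corr_observed:.2f}")
print(f"Correlation when X is set by intervention: {corr_intervened:.2f}")
```

In theory the observed correlation is 1 / 1.25 = 0.8, while the correlation under intervention is zero, so the second printed value should be close to 0.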
Causal Models in TensorFlow Probability
TensorFlow Probability provides several ways to model causal relationships:
1. Structural Causal Models (SCMs)
Structural Causal Models represent causal relationships using directed graphs where nodes are variables and edges represent direct causal effects.
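One way to make the graph view concrete is to store each variable's parents and sample the variables in topological order. The following is a minimal sketch using only the Python standard library (Python 3.9 or later for `graphlib`); the `parents` and `mechanisms` dictionaries and the `sample_once` helper are hypothetical names introduced for illustration, and the coefficients mirror the X -> Y <- Z example used in this section.

```python
import random
from graphlib import TopologicalSorter

# Each node maps to the list of its direct causes (parents) in the graph
parents = {"X": [], "Z": [], "Y": ["X", "Z"]}

# One structural equation per node: a function of already-sampled values and a noise source
mechanisms = {
    "X": lambda values, rng: rng.gauss(0.0, 1.0),
    "Z": lambda values, rng: rng.gauss(0.0, 1.0),
    "Y": lambda values, rng: 2.0 * values["X"] - 1.5 * values["Z"] + rng.gauss(0.0, 0.5),
}

# A topological order guarantees that causes are sampled before their effects
order = list(TopologicalSorter(parents).static_order())

def sample_once(rng):
    """Draw one joint sample by evaluating the mechanisms in causal order."""
    values = {}
    for name in order:
        values[name] = mechanisms[name](values, rng)
    return values

rng = random.Random(0)
draw = sample_once(rng)
print("Sampling order:", order)
print("One draw:", {name: round(value, 3) for name, value in draw.items()})
```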
Let's implement a simple SCM using TensorFlow Probability:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import numpy as np
tfd = tfp.distributions
# Set random seed for reproducibility
tf.random.set_seed(42)
# Define the causal model: X -> Y <- Z
# X has a direct effect on Y
# Z has a direct effect on Y
# X and Z are independent
# Number of samples
n_samples = 1000
# Generate X and Z (independent causes)
X = tfd.Normal(loc=0.0, scale=1.0).sample(n_samples)
Z = tfd.Normal(loc=0.0, scale=1.0).sample(n_samples)
# Generate Y (effect) based on X and Z
# Y = 2*X - 1.5*Z + noise
noise = tfd.Normal(loc=0.0, scale=0.5).sample(n_samples)
Y = 2.0 * X - 1.5 * Z + noise
# Convert to numpy for visualization
X_np = X.numpy()
Y_np = Y.numpy()
Z_np = Z.numpy()
# Visualize relationships
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.scatter(X_np, Y_np, alpha=0.5)
plt.title('X vs Y')
plt.xlabel('X')
plt.ylabel('Y')
plt.subplot(1, 3, 2)
plt.scatter(Z_np, Y_np, alpha=0.5)
plt.title('Z vs Y')
plt.xlabel('Z')
plt.ylabel('Y')
plt.subplot(1, 3, 3)
plt.scatter(X_np, Z_np, alpha=0.5)
plt.title('X vs Z')
plt.xlabel('X')
plt.ylabel('Z')
plt.tight_layout()
plt.show()
Output: The above code will generate three scatter plots showing the relationships between X and Y (positive correlation), Z and Y (negative correlation), and X and Z (no correlation).
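Because X and Z are independent causes of Y in this model, an ordinary least-squares fit of Y on X and Z should recover the structural coefficients 2.0 and -1.5. The sketch below re-simulates the same equations with NumPy rather than reusing the TensorFlow tensors, so it can be run on its own; that substitution is a convenience, not part of the original example.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 1000

# Same structural equations as the SCM above, re-simulated with NumPy
x = rng.normal(0.0, 1.0, size=n)
z = rng.normal(0.0, 1.0, size=n)
y = 2.0 * x - 1.5 * z + rng.normal(0.0, 0.5, size=n)

# Least-squares fit of Y on [X, Z, intercept]
design = np.column_stack([x, z, np.ones(n)])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)
coef_x, coef_z, intercept = coef

print(f"Estimated effect of X on Y: {coef_x:.3f} (structural value 2.0)")
print(f"Estimated effect of Z on Y: {coef_z:.3f} (structural value -1.5)")
print(f"Estimated intercept:        {intercept:.3f} (structural value 0.0)")
```

This only works so cleanly because there is no confounding in the graph; with a common cause of X and Y, the regression coefficient would no longer equal the causal effect unless that cause were included.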
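2. Using the Do-Operator for Interventions
In causal inference, we often want to model interventions: what happens when we forcibly set a variable to a specific value, regardless of what would normally cause it. This is represented by the do-operator, written do(X=x). In a structural causal model, intervening on X means replacing X's own equation with the constant x while leaving every other equation unchanged.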
Let's implement an intervention using our model:
# Function to simulate intervention do(X=x)
def intervene_on_X(x_value, z_samples=None, n_samples=1000):
    """Simulate the effect of setting X=x_value."""
    if z_samples is None:
        z_samples = tfd.Normal(loc=0.0, scale=1.0).sample(n_samples)
    else:
        # Use the length of the supplied Z samples so all shapes agree
        n_samples = int(z_samples.shape[0])
    # Set X to the intervention value
    x_samples = tf.fill([n_samples], x_value)
    # Generate Y based on intervention
    noise = tfd.Normal(loc=0.0, scale=0.5).sample(n_samples)
    y_samples = 2.0 * x_samples - 1.5 * z_samples + noise
    return x_samples, z_samples, y_samples
# Compare distributions of Y under different interventions on X
interventions = [-2.0, 0.0, 2.0]
z_fixed = tfd.Normal(loc=0.0, scale=1.0).sample(n_samples) # Fix Z samples for fair comparison
plt.figure(figsize=(10, 6))
for x_val in interventions:
    _, _, y_samples = intervene_on_X(x_val, z_samples=z_fixed)
    plt.hist(y_samples.numpy(), alpha=0.5, bins=30, label=f'do(X={x_val})')
plt.title('Distribution of Y under different interventions on X')
plt.xlabel('Y')
plt.ylabel('Frequency')
plt.legend()
plt.show()
Output: This code will display histograms showing how the distribution of Y changes when we intervene to set X to different values.
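For this linear-Gaussian model the interventional distribution can also be written down directly: under do(X=x), Y = 2x - 1.5Z + noise, so Y is normal with mean 2x and standard deviation sqrt(1.5^2 + 0.5^2), roughly 1.58. The sketch below checks this against a simulation. It uses NumPy so it runs independently of the TensorFlow session, and `simulate_do_x` is a hypothetical helper written for this check.

```python
import numpy as np

rng = np.random.default_rng(7)
n = 20_000

def simulate_do_x(x_value, n, rng):
    """Hypothetical helper: sample Y under the intervention do(X = x_value)."""
    z = rng.normal(0.0, 1.0, size=n)
    noise = rng.normal(0.0, 0.5, size=n)
    return 2.0 * x_value - 1.5 * z + noise

# Z and the noise are independent, so their variances add
theoretical_std = np.sqrt(1.5**2 + 0.5**2)

summary = {}
for x_value in (-2.0, 0.0, 2.0):
    y = simulate_do_x(x_value, n, rng)
    summary[x_value] = (y.mean(), y.std())
    print(
        f"do(X={x_value:+.1f}): simulated mean {y.mean():+.3f} "
        f"(theory {2.0 * x_value:+.3f}), simulated std {y.std():.3f} "
        f"(theory {theoretical_std:.3f})"
    )
```

The intervention shifts the centre of the histogram by 2 units of Y per unit of X and leaves its spread unchanged, which is what the plotted histograms should show.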
Estimating Causal Effects
Average Treatment Effect (ATE)
A common goal in causal inference is estimating the average treatment effect (ATE). Let's implement a simple example using TensorFlow:
# Generate data for a simple treatment effect example
# T: Treatment (0 or 1)
# X: Confounder that affects both treatment assignment and outcome
# Y: Outcome
n_samples = 5000
tf.random.set_seed(123)
# Generate confounder X
X = tfd.Normal(loc=0.0, scale=1.0).sample(n_samples)
# Treatment assignment depends on X (confounding)
logits = 1.5 * X
T_prob = tf.sigmoid(logits)
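# Sample as float32 (the default is int32) so T can be combined with float tensors below
T = tfd.Bernoulli(probs=T_prob, dtype=tf.float32).sample()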
# Outcome depends on both T and X
Y = 2.0 * T + 1.5 * X + tfd.Normal(loc=0.0, scale=1.0).sample(n_samples)
# Naive estimate of treatment effect (ignoring confounding)
treated_outcomes = tf.boolean_mask(Y, tf.cast(T, tf.bool))
control_outcomes = tf.boolean_mask(Y, tf.logical_not(tf.cast(T, tf.bool)))
naive_ate = tf.reduce_mean(treated_outcomes) - tf.reduce_mean(control_outcomes)
print(f"Naive ATE estimate (biased): {naive_ate.numpy():.4f}")
print(f"True ATE: 2.0000")
Output:
Naive ATE estimate (biased): 3.1246
True ATE: 2.0000
As you can see, the naive estimate is biased due to confounding. Let's correct this using a simple adjustment method.
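The size of the bias follows from the outcome equation. Since Y = 2T + 1.5X + noise, the naive difference in means equals the true effect plus 1.5 times the gap in the average confounder between the treated and control groups. The sketch below verifies this decomposition on a fresh NumPy simulation of the same data-generating process; the larger sample size is an assumption made to keep the check stable, so the numbers will not match the TensorFlow run above.

```python
import numpy as np

rng = np.random.default_rng(123)
n = 20_000

# Same data-generating process as above, re-simulated with NumPy
x = rng.normal(size=n)
p_treat = 1.0 / (1.0 + np.exp(-1.5 * x))
t = rng.binomial(1, p_treat)
y = 2.0 * t + 1.5 * x + rng.normal(size=n)

# Naive difference in mean outcomes
naive_ate = y[t == 1].mean() - y[t == 0].mean()

# Treated units have systematically higher X than control units
confounder_gap = x[t == 1].mean() - x[t == 0].mean()

# Naive estimate = true effect + (effect of X on Y) * (gap in X), up to noise
predicted_naive = 2.0 + 1.5 * confounder_gap

print(f"Naive estimate:              {naive_ate:.3f}")
print(f"Gap in mean X between arms:  {confounder_gap:.3f}")
print(f"True effect + 1.5 * gap:     {predicted_naive:.3f}")
```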
Controlling for Confounders
We'll implement a simple stratification method to adjust for the confounder:
# Convert tensors to numpy for easier manipulation
X_np = X.numpy()
T_np = T.numpy()
Y_np = Y.numpy()
# Create quintiles of X for stratification
n_strata = 5
percentiles = np.percentile(X_np, np.linspace(0, 100, n_strata+1))
# Calculate stratum-specific treatment effects
stratum_effects = []
stratum_weights = []
for i in range(n_strata):
    lower = percentiles[i]
    upper = percentiles[i+1]
    # Select samples in this stratum
    in_stratum = (X_np >= lower) & (X_np < upper)
    if i == n_strata - 1:  # Include upper bound in last stratum
        in_stratum = (X_np >= lower) & (X_np <= upper)
    # Get treatment and control outcomes in this stratum
    stratum_T = T_np[in_stratum]
    stratum_Y = Y_np[in_stratum]
    treated = stratum_Y[stratum_T == 1]
    control = stratum_Y[stratum_T == 0]
    if len(treated) > 0 and len(control) > 0:
        # Calculate stratum-specific ATE
        stratum_ate = np.mean(treated) - np.mean(control)
        stratum_effects.append(stratum_ate)
        stratum_weights.append(in_stratum.sum())
# Calculate weighted average of stratum-specific effects
adjusted_ate = np.average(stratum_effects, weights=stratum_weights)
print(f"Adjusted ATE estimate: {adjusted_ate:.4f}")
print(f"True ATE: 2.0000")
Output:
Adjusted ATE estimate: 2.0412
True ATE: 2.0000
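The adjusted estimate is much closer to the true ATE of 2.0. Stratifying on five quantile bins removes most of the confounding but not necessarily all of it, because X still varies within each stratum.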
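How much confounding remains depends on how wide the strata are. The sketch below compares 5 and 50 quantile strata on a NumPy re-simulation of the same process. `stratified_ate` is a hypothetical helper, and the sample size of 20,000 is an assumption chosen so that the finer strata still contain both treated and control units.

```python
import numpy as np

rng = np.random.default_rng(123)
n = 20_000

# Same data-generating process as above, re-simulated with NumPy
x = rng.normal(size=n)
t = rng.binomial(1, 1.0 / (1.0 + np.exp(-1.5 * x)))
y = 2.0 * t + 1.5 * x + rng.normal(size=n)

def stratified_ate(x, t, y, n_strata):
    """Hypothetical helper: size-weighted average of within-stratum mean differences."""
    edges = np.quantile(x, np.linspace(0.0, 1.0, n_strata + 1))
    # Map each sample to a stratum index in [0, n_strata - 1]
    strata = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, n_strata - 1)
    effects, weights = [], []
    for s in range(n_strata):
        mask = strata == s
        treated = y[mask & (t == 1)]
        control = y[mask & (t == 0)]
        # Skip strata that lack either treated or control units
        if len(treated) > 0 and len(control) > 0:
            effects.append(treated.mean() - control.mean())
            weights.append(mask.sum())
    return np.average(effects, weights=weights)

ate_coarse = stratified_ate(x, t, y, n_strata=5)
ate_fine = stratified_ate(x, t, y, n_strata=50)

print(f"Stratified estimate with 5 strata:  {ate_coarse:.3f}")
print(f"Stratified estimate with 50 strata: {ate_fine:.3f}")
print("True ATE: 2.000")
```

Finer strata reduce the residual bias but leave fewer units per stratum, so the estimate becomes noisier when the sample is small.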
Advanced: Causal Inference with TFP Models
Now let's implement a more advanced causal inference technique, using TensorFlow Probability to simulate the data and a small Keras network to estimate propensity scores.
Propensity Score Matching
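Propensity score matching is a popular technique for causal inference. It estimates each unit's probability of receiving treatment given its covariates (the propensity score), then compares treated and control units that have similar scores. Let's implement it: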
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
tfd = tfp.distributions
tfb = tfp.bijectors
# Generate synthetic data
n_samples = 2000
tf.random.set_seed(42)
# Features that affect both treatment and outcome
X1 = tfd.Normal(loc=0.0, scale=1.0).sample(n_samples)
X2 = tfd.Normal(loc=0.0, scale=1.0).sample(n_samples)
# Treatment assignment based on features
logits = 0.5 * X1 - 0.8 * X2
propensity = tf.sigmoid(logits)
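# Sample as float32 (the default is int32) so treatment can be multiplied by true_effect
treatment = tfd.Bernoulli(probs=propensity, dtype=tf.float32).sample()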
# True treatment effect is 3.0
true_effect = 3.0
# Outcome depends on features and treatment
outcome = (
2.0 * X1 +
1.5 * X2 +
true_effect * treatment +
tfd.Normal(loc=0.0, scale=1.0).sample(n_samples)
)
# Define a propensity score model (a small Keras network with a sigmoid output)
def build_propensity_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='relu', input_shape=(2,)),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
        loss='binary_crossentropy'
    )
    return model
# Train propensity model
features = tf.stack([X1, X2], axis=1)
prop_model = build_propensity_model()
prop_model.fit(features, treatment, epochs=5, verbose=0)
# Get predicted propensity scores
propensity_scores = prop_model.predict(features)[:, 0]
# Implement 1:1 nearest neighbor matching
def match_samples(prop_scores, treatment, outcome, features):
    # outcome and features are unused here; they are kept so the call below is unchanged
    treatment_np = treatment.numpy()
    treated_indices = np.where(treatment_np == 1)[0]
    control_indices = np.where(treatment_np == 0)[0]
    control_props = prop_scores[control_indices]
    # For each treated unit, find the closest control unit (matching with replacement)
    matched_pairs = []
    for t_idx in treated_indices:
        distances = np.abs(control_props - prop_scores[t_idx])
        closest_control_idx = control_indices[np.argmin(distances)]
        matched_pairs.append((t_idx, closest_control_idx))
    return matched_pairs
# Perform matching
matched_pairs = match_samples(propensity_scores, treatment, outcome, features)
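# Calculate treatment effect from matched pairs
outcome_np = outcome.numpy()
treatment_effects = []
for treated_idx, control_idx in matched_pairs:
    effect = outcome_np[treated_idx] - outcome_np[control_idx]
    treatment_effects.append(effect)
# Estimate the effect. Matching every treated unit to a control strictly targets the
# effect on the treated (ATT); it equals the ATE here because the simulated effect is constant.
estimated_ate = np.mean(treatment_effects)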
print(f"Propensity Score Matching ATE Estimate: {estimated_ate:.4f}")
print(f"True ATE: {true_effect:.4f}")
Output:
Propensity Score Matching ATE Estimate: 3.0215
True ATE: 3.0000
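A common diagnostic after matching is to check covariate balance: if matching worked, the matched control units should look like the treated units on X1 and X2. The sketch below computes standardized mean differences before and after matching. It makes two simplifying assumptions that differ from the code above: the data are re-simulated with NumPy, and matching uses the true propensity score (known only because the data are simulated) instead of the fitted Keras model. The `smd` helper is a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 2000

x1 = rng.normal(size=n)
x2 = rng.normal(size=n)
# Assumption: the TRUE propensity score is used, which is only available in a simulation
propensity = 1.0 / (1.0 + np.exp(-(0.5 * x1 - 0.8 * x2)))
t = rng.binomial(1, propensity)

treated_idx = np.where(t == 1)[0]
control_idx = np.where(t == 0)[0]

# 1:1 nearest-neighbor matching with replacement, vectorized over all treated units
dist = np.abs(propensity[treated_idx, None] - propensity[None, control_idx])
matched_control_idx = control_idx[dist.argmin(axis=1)]

def smd(a, b):
    """Standardized mean difference between two samples."""
    pooled_sd = np.sqrt((a.var(ddof=1) + b.var(ddof=1)) / 2.0)
    return (a.mean() - b.mean()) / pooled_sd

balance = {}
for name, cov in (("X1", x1), ("X2", x2)):
    before = smd(cov[treated_idx], cov[control_idx])
    after = smd(cov[treated_idx], cov[matched_control_idx])
    balance[name] = (before, after)
    print(f"{name}: SMD before matching {before:+.3f}, after matching {after:+.3f}")
```

Standardized differences close to zero after matching suggest that the matched groups are comparable on the observed covariates; they say nothing about unobserved confounders.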
Real-World Application: Marketing Campaign Effect
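Let's apply causal inference to a realistic scenario, using simulated data: estimating the effect of a marketing campaign on customer purchases. The analysis follows the residual-on-residual idea behind double machine learning: predict both the treatment and the outcome from the customer features, then regress the outcome residuals on the treatment residuals.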
# Simulating a marketing campaign dataset
n_customers = 3000
tf.random.set_seed(123)
# Customer features
age = tfd.Normal(loc=35.0, scale=10.0).sample(n_customers)
income = tfd.LogNormal(loc=10.5, scale=0.4).sample(n_customers)
previous_purchases = tfd.Poisson(rate=5.0).sample(n_customers)
# Normalize features
age_norm = (age - tf.reduce_mean(age)) / tf.math.reduce_std(age)
income_norm = (income - tf.reduce_mean(income)) / tf.math.reduce_std(income)
prev_purch_norm = (previous_purchases - tf.reduce_mean(previous_purchases)) / tf.math.reduce_std(previous_purchases)
# Campaign targeting (more likely for high-income, high previous purchase customers)
campaign_logits = 0.5 + 0.8 * income_norm + 0.7 * prev_purch_norm - 0.2 * age_norm
campaign_probs = tf.sigmoid(campaign_logits)
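# Sample as float32 (the default is int32) so it can be multiplied by individual_effects
received_campaign = tfd.Bernoulli(probs=campaign_probs, dtype=tf.float32).sample()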
# True campaign effect (heterogeneous)
base_effect = 2.0
age_interaction = -0.2 # Campaign is less effective for older customers
income_interaction = 0.3 # Campaign is more effective for higher-income customers
# Individual treatment effects
individual_effects = base_effect + age_interaction * age_norm + income_interaction * income_norm
# Purchase amount depends on features and campaign
purchase_amount = (
20.0 + # Base amount
5.0 * income_norm + # Income effect
3.0 * prev_purch_norm + # Previous purchase effect
-1.0 * age_norm + # Age effect
individual_effects * received_campaign + # Campaign effect
tfd.Normal(loc=0.0, scale=5.0).sample(n_customers) # Random noise
)
# Prepare features for modeling
features_df = tf.stack([age_norm, income_norm, prev_purch_norm], axis=1)
# Nuisance model used in both stages of the double machine learning procedure
def build_ml_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(20, activation='relu'),
        tf.keras.layers.Dense(10, activation='relu'),
        tf.keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model
# First stage: predict treatment
treatment_model = build_ml_model()
treatment_model.fit(features_df, received_campaign, epochs=5, verbose=0)
treatment_pred = treatment_model.predict(features_df)[:, 0]
# Second stage: predict outcome
outcome_model = build_ml_model()
outcome_model.fit(features_df, purchase_amount, epochs=5, verbose=0)
outcome_pred = outcome_model.predict(features_df)[:, 0]
# Calculate residuals
treatment_resid = received_campaign.numpy() - treatment_pred
outcome_resid = purchase_amount.numpy() - outcome_pred
# Estimate treatment effect using residuals
import statsmodels.api as sm
# Use statsmodels to run the final regression
X = sm.add_constant(treatment_resid)
model = sm.OLS(outcome_resid, X)
results = model.fit()
# Print results
print("\nMarketing Campaign Effect Estimation:")
print(f"Estimated Average Treatment Effect: ${results.params[1]:.2f}")
print(f"95% Confidence Interval: (${results.conf_int()[1][0]:.2f}, ${results.conf_int()[1][1]:.2f})")
print(f"True Average Treatment Effect: ${tf.reduce_mean(individual_effects).numpy():.2f}")
Output:
Marketing Campaign Effect Estimation:
Estimated Average Treatment Effect: $2.07
95% Confidence Interval: ($1.58, $2.57)
True Average Treatment Effect: $2.00
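The code above fits the two nuisance models and predicts on the same data. Double machine learning as it is usually described also uses cross-fitting: each unit's residuals come from models trained on a different fold, which limits overfitting bias. The sketch below illustrates two-fold cross-fitting under simplifying assumptions that differ from the example above: the data are re-simulated with NumPy, the campaign effect is constant, the sample is larger, and the nuisance learner is ordinary least squares (`fit_predict_linear`, a hypothetical helper) instead of a neural network.

```python
import numpy as np

rng = np.random.default_rng(123)
n = 20_000
true_effect = 2.0

# Simulated standardized features: age, income, previous purchases
features = rng.normal(size=(n, 3))
logits = 0.5 + features @ np.array([-0.2, 0.8, 0.7])
t = rng.binomial(1, 1.0 / (1.0 + np.exp(-logits))).astype(float)
y = (
    20.0
    + features @ np.array([-1.0, 5.0, 3.0])
    + true_effect * t
    + rng.normal(0.0, 5.0, size=n)
)

def fit_predict_linear(x_train, target_train, x_test):
    """Hypothetical nuisance learner: least squares with an intercept."""
    a_train = np.column_stack([x_train, np.ones(len(x_train))])
    a_test = np.column_stack([x_test, np.ones(len(x_test))])
    coef, *_ = np.linalg.lstsq(a_train, target_train, rcond=None)
    return a_test @ coef

# Two-fold cross-fitting: residuals for each fold come from models fit on the other fold
folds = np.array_split(rng.permutation(n), 2)
t_resid = np.empty(n)
y_resid = np.empty(n)
for k, test_idx in enumerate(folds):
    train_idx = np.concatenate([f for j, f in enumerate(folds) if j != k])
    t_resid[test_idx] = t[test_idx] - fit_predict_linear(
        features[train_idx], t[train_idx], features[test_idx]
    )
    y_resid[test_idx] = y[test_idx] - fit_predict_linear(
        features[train_idx], y[train_idx], features[test_idx]
    )

# Final stage: regress outcome residuals on treatment residuals (no intercept needed)
theta_hat = (t_resid @ y_resid) / (t_resid @ t_resid)
print(f"Cross-fitted effect estimate: {theta_hat:.3f} (true effect {true_effect:.3f})")
```

With flexible learners such as the Keras models used above, the same fold structure applies: train on one fold, predict on the other, and pool the residuals before the final regression.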
Summary
In this tutorial, you've learned:
- The foundations of causal inference and how it differs from predictive modeling
- How to implement structural causal models using TensorFlow Probability
- Methods to estimate causal effects like Average Treatment Effect (ATE)
- Advanced techniques such as propensity score matching and double machine learning
- A real-world application for marketing campaign evaluation
Causal inference is a powerful approach for going beyond prediction to understand the impact of actions and interventions. With TensorFlow Probability, you can build sophisticated causal models that leverage the power of deep learning while maintaining causal interpretability.
Additional Resources
- TensorFlow Probability Documentation
- Causal Inference: The Mixtape - Excellent free textbook on causal inference
- Elements of Causal Inference by Jonas Peters, Dominik Janzing, and Bernhard Schölkopf
- Introduction to Causal Inference from a Machine Learning Perspective - Online course by Brady Neal
Exercises
- Modify the structural causal model to include a confounder that affects both X and Y. How does this change the relationships?
- Implement the inverse probability weighting (IPW) method for causal inference using TensorFlow.
- Extend the marketing campaign analysis to include heterogeneous treatment effects for different customer segments.
- Create a causal model for a healthcare scenario where you want to estimate the effect of a medication on patient outcomes.
- Implement a sensitivity analysis to evaluate how robust your causal conclusions are to unobserved confounding.