
TensorFlow Decision Forests

Decision forests are among the most powerful machine learning algorithms available today, known for their accuracy, robustness, and versatility. TensorFlow Decision Forests (TF-DF) brings these algorithms into the TensorFlow ecosystem, making them accessible to beginners and experts alike.

What are Decision Forests?

Decision forests are collections of decision trees that work together to make predictions. Unlike neural networks, which learn through gradient descent and backpropagation, decision trees make decisions by splitting data based on features in a hierarchical manner.
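
To make the contrast concrete, here is a purely illustrative sketch of the kind of hierarchical rule a single trained tree encodes (plain Python, not TF-DF; the feature names and thresholds are made up):

python
# A hypothetical two-level decision tree for a binary label.
def tree_predict(example):
    if example["feature_a"] < 3.5:       # root split
        return 1                         # leaf prediction
    if example["feature_b"] == "red":    # second-level split
        return 1                         # leaf prediction
    return 0                             # leaf prediction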

Decision forests include several popular algorithms, each of which TF-DF exposes as a Keras model class (see the sketch after this list):

  • Random Forests: Combines multiple decision trees trained on random subsets of data and features
  • Gradient Boosted Trees: Builds trees sequentially, with each tree correcting errors made by previous trees
  • CART (Classification and Regression Trees): Single decision trees used for both classification and regression
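
In TF-DF, each of these corresponds to a Keras model class:

python
import tensorflow_decision_forests as tfdf

rf_model = tfdf.keras.RandomForestModel()           # Random Forest
gbt_model = tfdf.keras.GradientBoostedTreesModel()  # Gradient Boosted Trees
cart_model = tfdf.keras.CartModel()                 # a single CART tree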

Why Use TensorFlow Decision Forests?

  • Performance: Often outperform neural networks on structured/tabular data
  • Less hyperparameter tuning: Generally work well with default configurations
  • Interpretability: Easier to understand how decisions are made
  • Minimal preprocessing: Handle missing values and categorical features automatically
  • Versatility: Work well for both classification and regression problems

Installation

To get started with TF-DF, install it via pip:

bash
pip install tensorflow_decision_forests

After installation, you can import it in your Python code:

python
import tensorflow as tf
import tensorflow_decision_forests as tfdf
import pandas as pd

Basic Usage

Let's walk through a simple example using the classic Titanic dataset to predict passenger survival:

Step 1: Prepare your data

python
# Load the dataset
dataset_df = pd.read_csv("https://storage.googleapis.com/tf-datasets/titanic/train.csv")

# Display the first few rows
print(dataset_df.head())

Output:

   survived     sex   age  n_siblings_spouses  parch     fare  class     deck  embark_town alone
0         0    male  22.0                   1      0   7.2500  Third  unknown  Southampton     n
1         1  female  38.0                   1      0  71.2833  First        C    Cherbourg     n
2         1  female  26.0                   0      0   7.9250  Third  unknown  Southampton     y
3         1  female  35.0                   1      0  53.1000  First        C  Southampton     n
4         0    male  28.0                   0      0   8.4583  Third  unknown   Queenstown     y

Step 2: Convert to TensorFlow dataset

python
# Split into training and testing sets (80/20)
train_ds_pd = dataset_df.sample(frac=0.8, random_state=1234)
test_ds_pd = dataset_df.drop(train_ds_pd.index)

# Convert to TensorFlow Datasets
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label="survived")
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label="survived")

Step 3: Create and train the model

python
# Create a Random Forest model
model = tfdf.keras.RandomForestModel(verbose=2)

# Train the model
model.fit(train_ds)

Output:

Starting to grow 300 trees...
Training of tree 1/300 done in 0.01s
Training of tree 2/300 done in 0.01s
...
Training of tree 300/300 done in 0.01s
Training done in 2.52s

Step 4: Evaluate the model

python
# Compile with the metrics to report, then evaluate the model
model.compile(metrics=["accuracy"])
evaluation = model.evaluate(test_ds, return_dict=True)
print(f"Accuracy: {evaluation['accuracy']:.4f}")

Output:

Accuracy: 0.8136

Step 5: Inspect the model

One of the strengths of decision forests is their interpretability. Let's examine what our model learned:

python
# Get feature importance
print(model.make_inspector().variable_importances())

Output:

[('sex', {'MEAN_DECREASE_IN_ACCURACY': 0.142}),
('fare', {'MEAN_DECREASE_IN_ACCURACY': 0.121}),
('age', {'MEAN_DECREASE_IN_ACCURACY': 0.103}),
('class', {'MEAN_DECREASE_IN_ACCURACY': 0.092}),
('n_siblings_spouses', {'MEAN_DECREASE_IN_ACCURACY': 0.057}),
...
]

This tells us that "sex" was the most important feature in predicting survival.
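
The inspector exposes more than variable importances. For Random Forests it also reports the out-of-bag (OOB) evaluation, a built-in quality estimate that needs no held-out data. A short sketch (which importance metrics are available depends on the model type and its training configuration):

python
inspector = model.make_inspector()

# List the importance metrics that were computed for this model
print(inspector.variable_importances().keys())

# Out-of-bag evaluation (Random Forest only)
print(inspector.evaluation())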

Step 6: Make predictions

python
# Create a sample passenger
sample = {
    "sex": ["female"],
    "age": [28.0],
    "n_siblings_spouses": [0],
    "parch": [0],
    "fare": [30.0],
    "class": ["First"],
    "deck": ["unknown"],
    "embark_town": ["Southampton"],
    "alone": ["y"]
}

# Convert to a TensorFlow dataset
sample_ds = tfdf.keras.pd_dataframe_to_tf_dataset(pd.DataFrame(sample))

# Make a prediction
predictions = model.predict(sample_ds)
print(f"Survival probability: {predictions[0][0]:.4f}")

Output:

Survival probability: 0.8967

Advanced Features

Gradient Boosted Trees

TF-DF also supports gradient boosted trees, which often achieve even better performance:

python
# Create a Gradient Boosted Trees model
gbdt_model = tfdf.keras.GradientBoostedTreesModel(verbose=2)

# Train the model
gbdt_model.fit(train_ds)

# Compile with metrics, then evaluate
gbdt_model.compile(metrics=["accuracy"])
gbdt_evaluation = gbdt_model.evaluate(test_ds, return_dict=True)
print(f"GBDT Accuracy: {gbdt_evaluation['accuracy']:.4f}")

Output:

GBDT Accuracy: 0.8305
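
Gradient boosted trees hold out part of the training data (10% by default) as an internal validation split, and the inspector keeps per-iteration training logs. A short sketch for checking how validation accuracy evolved as trees were added:

python
# Each log entry records the model size and its evaluation on the
# internal validation split at that point in training.
logs = gbdt_model.make_inspector().training_logs()
for log in logs[:5]:
    print(log.num_trees, log.evaluation.accuracy)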

Hyper-Parameter Tuning

While decision forests often work well with default parameters, you can tune them for better performance:

python
# Create a Random Forest model with custom parameters
tuned_rf = tfdf.keras.RandomForestModel(
    num_trees=500,
    max_depth=10,
    min_examples=5,
    bootstrap_training_dataset=True,
    verbose=0
)

# Train, compile with metrics, and evaluate
tuned_rf.fit(train_ds)
tuned_rf.compile(metrics=["accuracy"])
tuned_eval = tuned_rf.evaluate(test_ds, return_dict=True)
print(f"Tuned RF Accuracy: {tuned_eval['accuracy']:.4f}")

Model Explanation by Plotting Trees

For more detailed interpretation, TF-DF can render individual trees from the forest, showing exactly which splits the model uses. (Note that TF-DF does not ship a built-in SHAP implementation; SHAP-style attributions require an external library.)

python
# Plot the first tree of the Random Forest (renders in Colab/Jupyter)
tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)

# Or get the same plot as an HTML string to save or embed
html = tfdf.model_plotter.plot_model(model, tree_idx=0, max_depth=3)

Real-World Application: Predicting Customer Churn

Let's apply TF-DF to a common business problem: predicting which customers are likely to churn (cancel their service).

python
# Assume we have a dataset of customer information
# Features might include: service_duration, monthly_charges, total_charges,
# contract_type, payment_method, customer_support_calls, etc.

# Load customer data (simulated here)
churn_df = pd.DataFrame({
    "service_duration": [12, 24, 6, 36, 3, 48, 18],
    "monthly_charges": [65.5, 59.9, 89.0, 54.2, 94.5, 44.8, 74.2],
    "total_charges": [786, 1437, 534, 1951, 283, 2150, 1335],
    "contract_type": ["Month-to-month", "One year", "Month-to-month", "Two year", "Month-to-month", "Two year", "One year"],
    "payment_method": ["Credit card", "Bank transfer", "Credit card", "Bank transfer", "Electronic check", "Bank transfer", "Credit card"],
    "customer_support_calls": [0, 1, 5, 0, 7, 2, 3],
    "churn": [0, 0, 1, 0, 1, 0, 0]  # Target variable
})

# Convert to TensorFlow Dataset
churn_train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(churn_df, label="churn")

# Train a Gradient Boosted Trees model. The toy dataset is tiny, so we
# disable the internal validation split (and, with it, early stopping).
churn_model = tfdf.keras.GradientBoostedTreesModel(
    validation_ratio=0.0, early_stopping="NONE"
)
churn_model.fit(churn_train_ds)

# Check feature importance
print("Feature importance for churn prediction:")
print(churn_model.make_inspector().variable_importances())

# Make a prediction for a new customer
new_customer = pd.DataFrame({
    "service_duration": [2],
    "monthly_charges": [95.0],
    "total_charges": [190],
    "contract_type": ["Month-to-month"],
    "payment_method": ["Electronic check"],
    "customer_support_calls": [6]
})

new_customer_ds = tfdf.keras.pd_dataframe_to_tf_dataset(new_customer)
churn_prob = churn_model.predict(new_customer_ds)
print(f"Churn probability: {churn_prob[0][0]:.4f}")

Output:

Feature importance for churn prediction:
[('customer_support_calls', {'MEAN_DECREASE_IN_ACCURACY': 0.324}),
('service_duration', {'MEAN_DECREASE_IN_ACCURACY': 0.245}),
('contract_type', {'MEAN_DECREASE_IN_ACCURACY': 0.187}),
('payment_method', {'MEAN_DECREASE_IN_ACCURACY': 0.125}),
('monthly_charges', {'MEAN_DECREASE_IN_ACCURACY': 0.074}),
('total_charges', {'MEAN_DECREASE_IN_ACCURACY': 0.045})]

Churn probability: 0.8932

With this prediction, a business could proactively reach out to this customer with retention offers before they cancel their service.
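
In practice you would score customers in bulk and act on the high-risk ones. A hypothetical sketch (the 0.5 threshold is an illustrative choice; set it based on the relative cost of a false alarm versus a missed churner):

python
# Score every known customer and flag likely churners
all_customers_ds = tfdf.keras.pd_dataframe_to_tf_dataset(
    churn_df.drop(columns=["churn"])
)
scores = churn_model.predict(all_customers_ds)
at_risk = churn_df[scores.flatten() > 0.5]
print(f"{len(at_risk)} customers flagged for retention outreach")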

Saving and Loading Models

TF-DF models can be saved and loaded like any TensorFlow model:

python
# Save the model
model.save("my_forest_model")

# Load the model
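# (tensorflow_decision_forests must already be imported, so that the
# custom inference ops are registered with TensorFlow.)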
loaded_model = tf.keras.models.load_model("my_forest_model")

# Use the loaded model
loaded_model_evaluation = loaded_model.evaluate(test_ds, return_dict=True)
print(f"Loaded model accuracy: {loaded_model_evaluation['accuracy']:.4f}")

Summary

TensorFlow Decision Forests provides an excellent alternative to neural networks, especially for tabular data where decision forests often outperform deep learning approaches. Key benefits include:

  • Excellent performance on structured data
  • Minimal preprocessing requirements
  • Built-in handling of missing values and categorical features
  • High interpretability through feature importances and tree plotting
  • Seamless integration with the TensorFlow ecosystem

Decision forests should be one of your first choices when working with tabular data, and TF-DF makes these powerful algorithms accessible within the familiar TensorFlow workflow.

Exercises

  1. Beginner: Train a Random Forest model on the Iris dataset and visualize the feature importance.
  2. Intermediate: Compare the performance of a Random Forest model and a Gradient Boosted Trees model on a dataset of your choice.
  3. Advanced: Build a model that predicts house prices using TF-DF and compare its performance with a deep neural network approach. Analyze which features are most important for predicting house prices.

