TensorFlow Custom Metrics
Introduction
When training machine learning models, measuring performance is crucial. While TensorFlow provides many built-in metrics like accuracy, precision, and recall, you may sometimes need specialized measurements for your specific problem. This is where custom metrics come in.
Custom metrics allow you to define exactly how you want to evaluate your model's performance. Whether you're working with unique data distributions, specialized domains like healthcare or finance, or simply need a measurement that isn't available out-of-the-box, TensorFlow provides a flexible framework for creating your own evaluation metrics.
In this tutorial, we'll explore:
- What metrics are in TensorFlow
- When and why to use custom metrics
- How to create custom metrics in different ways
- Practical examples of custom metrics
- Best practices for implementing your own metrics
Understanding TensorFlow Metrics
Before diving into custom metrics, let's understand what metrics are in TensorFlow.
Metrics in TensorFlow measure how well your model performs. They track values during training and testing, accumulating data across batches to provide statistical measures of model performance. Unlike losses, which are used to train the model, metrics are not differentiated or optimized; they serve only to evaluate model performance.
# Example of built-in metrics
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']  # Built-in metric
)
TensorFlow implements metrics as classes that inherit from tf.keras.metrics.Metric. Each metric maintains state variables to track statistics across batches and provides methods to update these values and return results.
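The update, result, and reset lifecycle can be illustrated without TensorFlow. The sketch below is a framework-free stand-in, not Keras code: the class name `RunningMean` is hypothetical, and only the three method names mirror the tf.keras.metrics.Metric interface.

```python
class RunningMean:
    """Hypothetical stand-in that mimics the Metric lifecycle in plain Python."""

    def __init__(self):
        self.total = 0.0   # state: sum of all values seen so far
        self.count = 0     # state: number of values seen so far

    def update_state(self, values):
        # Called once per batch: fold the batch into the running state
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Called whenever the current value is needed
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        # Clears the state, as Keras does between epochs
        self.total = 0.0
        self.count = 0


metric = RunningMean()
metric.update_state([1.0, 2.0, 3.0])        # first batch
metric.update_state([5.0])                  # second, smaller batch
mean_over_both_batches = metric.result()    # (1 + 2 + 3 + 5) / 4 = 2.75
metric.reset_state()
mean_after_reset = metric.result()          # 0.0, state was cleared
```

A real Keras metric stores its state in variables created with add_weight rather than plain attributes, but the division of labor between the three methods is the same.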
When to Create Custom Metrics
You might need custom metrics when:
- The built-in metrics don't capture the nuances of your problem
- You need domain-specific performance measures
- You're implementing a novel metric from research
- You need to modify an existing metric behavior
- You need to combine multiple metrics into a single value
Creating Custom Metrics in TensorFlow
There are three primary ways to create custom metrics in TensorFlow:
- Creating a function that takes y_true and y_pred as inputs
- Subclassing tf.keras.metrics.Metric
- Using tf.keras.metrics.MeanMetricWrapper for simple cases
Let's examine each approach.
1. Custom Metric Functions
The simplest way to create a custom metric is to define a function:
import tensorflow as tf

def custom_accuracy(y_true, y_pred):
    # Convert probabilities to class predictions
    y_pred_classes = tf.argmax(y_pred, axis=1)
    y_true_classes = tf.argmax(y_true, axis=1)
    # Calculate accuracy
    equality = tf.equal(y_pred_classes, y_true_classes)
    accuracy = tf.reduce_mean(tf.cast(equality, tf.float32))
    return accuracy

# Use in model compilation
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=[custom_accuracy]
)
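This approach works well for simple metrics that can be computed independently on each batch and don't need to maintain state across batches. Keras averages the values the function returns to report the result for the epoch.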
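To see why state matters, compare averaging a ratio per batch with computing it from counts accumulated over all batches. This is a plain-Python sketch with made-up counts. For a ratio such as precision the two answers can differ, which is the case where a stateful metric is the safer choice.

```python
# Made-up (true positives, false positives) counts for two batches
batches = [(1, 0), (1, 3)]

# Stateless style: compute precision per batch, then average the batch values
per_batch = [tp / (tp + fp) for tp, fp in batches]
mean_of_batches = sum(per_batch) / len(per_batch)      # (1.0 + 0.25) / 2 = 0.625

# Stateful style: accumulate counts first, divide once at the end
total_tp = sum(tp for tp, _ in batches)
total_fp = sum(fp for _, fp in batches)
global_precision = total_tp / (total_tp + total_fp)    # 2 / 5 = 0.4
```

Plain accuracy over equally sized batches does not show this gap, which is why a simple function is adequate for it.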
2. Subclassing tf.keras.metrics.Metric
For more complex metrics that need to maintain state between batches, subclass tf.keras.metrics.Metric:
import tensorflow as tf

class F1Score(tf.keras.metrics.Metric):
    def __init__(self, name='f1_score', **kwargs):
        super(F1Score, self).__init__(name=name, **kwargs)
        # Define state variables
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Update precision and recall
        y_pred_binary = tf.cast(tf.greater_equal(y_pred, 0.5), tf.float32)
        self.precision.update_state(y_true, y_pred_binary, sample_weight)
        self.recall.update_state(y_true, y_pred_binary, sample_weight)

    def result(self):
        # Calculate F1 using precision and recall
        precision = self.precision.result()
        recall = self.recall.result()
        return 2 * ((precision * recall) / (precision + recall + tf.keras.backend.epsilon()))

    def reset_state(self):
        # Reset all state variables
        self.precision.reset_state()
        self.recall.reset_state()

# Use in model compilation
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[F1Score()]
)
This approach is more powerful as it allows you to:
- Maintain state across batches
- Reset state between epochs
- Handle weighted samples
- Create more complex metrics that need to track multiple values
3. Using MeanMetricWrapper
For simple metrics that just need to average a function over batches:
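import tensorflow as tf

def mean_absolute_percentage_error(y_true, y_pred):
    y_true = tf.cast(y_true, y_pred.dtype)
    # Guard the denominator so targets at or near zero do not divide by zero
    denominator = tf.maximum(tf.abs(y_true), tf.keras.backend.epsilon())
    # Return one value per sample; MeanMetricWrapper averages them across batches
    return 100 * tf.reduce_mean(tf.abs(y_true - y_pred) / denominator, axis=-1)

# Create a metric from the function
MAPE = tf.keras.metrics.MeanMetricWrapper(
    fn=mean_absolute_percentage_error,
    name='mape'
)

# Use in model compilation
model.compile(
    optimizer='adam',
    loss='mse',
    metrics=[MAPE]
)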
Practical Examples
Let's explore some practical examples of custom metrics for different use cases.
Example 1: Custom Metric for Multi-Label Classification
In multi-label classification, we often need the "subset accuracy" (also called the exact match ratio), which counts a sample as correct only when every one of its labels is predicted correctly:
import tensorflow as tf

class SubsetAccuracy(tf.keras.metrics.Metric):
    def __init__(self, name='subset_accuracy', **kwargs):
        super(SubsetAccuracy, self).__init__(name=name, **kwargs)
        self.total = self.add_weight(name='total', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.float32)
        y_pred_binary = tf.cast(tf.greater_equal(y_pred, 0.5), tf.float32)
        # Check if all predictions match for each sample
        all_correct = tf.reduce_all(tf.equal(y_true, y_pred_binary), axis=1)
        all_correct = tf.cast(all_correct, tf.float32)
        if sample_weight is not None:
            # Weight both the numerator and the denominator
            sample_weight = tf.reshape(tf.cast(sample_weight, tf.float32), [-1])
            all_correct = tf.multiply(all_correct, sample_weight)
            self.total.assign_add(tf.reduce_sum(sample_weight))
        else:
            self.total.assign_add(tf.cast(tf.shape(y_true)[0], tf.float32))
        self.count.assign_add(tf.reduce_sum(all_correct))

    def result(self):
        # divide_no_nan returns 0 when no samples have been seen yet
        return tf.math.divide_no_nan(self.count, self.total)

    def reset_state(self):
        self.total.assign(0.0)
        self.count.assign(0.0)

# Example usage
# Sample data: 3 samples, 4 labels each
y_true = tf.constant([[1, 0, 1, 1], [0, 0, 1, 0], [1, 1, 0, 1]], dtype=tf.float32)
y_pred = tf.constant([[0.9, 0.2, 0.8, 0.9], [0.1, 0.2, 0.7, 0.3], [0.9, 0.9, 0.2, 0.8]], dtype=tf.float32)

# Create and use the metric
subset_acc = SubsetAccuracy()
subset_acc.update_state(y_true, y_pred)
print(f"Subset Accuracy: {subset_acc.result().numpy()}")
# Expected output: Subset Accuracy: 1.0 (after thresholding at 0.5, all 3 samples have every label correct)
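As a sanity check on the thresholding and the all-labels-must-match rule, the same computation can be done by hand in plain Python. The data below is a made-up variant of the sample data: the third sample gets its second label wrong, so only two of the three samples count as correct.

```python
y_true = [[1, 0, 1, 1], [0, 0, 1, 0], [1, 1, 0, 1]]
y_pred = [[0.9, 0.2, 0.8, 0.9], [0.1, 0.2, 0.7, 0.3], [0.9, 0.4, 0.2, 0.8]]

# Threshold probabilities at 0.5, as the metric does
y_pred_binary = [[1 if p >= 0.5 else 0 for p in row] for row in y_pred]

# A sample counts only if every label matches
exact_matches = [t == p for t, p in zip(y_true, y_pred_binary)]
subset_accuracy = sum(exact_matches) / len(exact_matches)   # 2 / 3
```

Note how strict the measure is: the third sample has three of four labels right and still scores zero.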
Example 2: Custom Regression Metric (R-squared)
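R-squared is a common metric for regression models. Older TensorFlow releases do not include it among the built-in metrics (recent Keras versions provide an R2Score metric), and implementing it by hand is a useful exercise in tracking state across batches: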
import tensorflow as tf

class RSquared(tf.keras.metrics.Metric):
    def __init__(self, name='r_squared', **kwargs):
        super(RSquared, self).__init__(name=name, **kwargs)
        # Running sums let us recover the mean of y_true over all batches seen
        self.sum_y = self.add_weight(name='sum_y', initializer='zeros')
        self.sum_y_squared = self.add_weight(name='sum_y_squared', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')
        self.ss_residual = self.add_weight(name='ss_residual', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Flatten the inputs (sample_weight is ignored in this simple version)
        y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
        y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
        # Update state variables
        self.sum_y.assign_add(tf.reduce_sum(y_true))
        self.sum_y_squared.assign_add(tf.reduce_sum(tf.square(y_true)))
        self.count.assign_add(tf.cast(tf.size(y_true), tf.float32))
        self.ss_residual.assign_add(tf.reduce_sum(tf.square(y_true - y_pred)))

    def result(self):
        # SS_total = sum(y^2) - (sum(y))^2 / n, which uses the mean over all batches
        ss_total = self.sum_y_squared - tf.square(self.sum_y) / tf.maximum(self.count, 1.0)
        # R² = 1 - (SS_res / SS_total)
        # Epsilon handles the edge case where ss_total is 0; the maximum also floors negative R² at 0
        return tf.maximum(0.0, 1.0 - self.ss_residual / (ss_total + tf.keras.backend.epsilon()))

    def reset_state(self):
        for var in (self.sum_y, self.sum_y_squared, self.count, self.ss_residual):
            var.assign(0.0)

# Example usage
y_true = tf.constant([3, 5, 7, 9, 11], dtype=tf.float32)
y_pred = tf.constant([2.8, 4.7, 7.2, 9.3, 10.5], dtype=tf.float32)
r_squared = RSquared()
r_squared.update_state(y_true, y_pred)
print(f"R-squared: {r_squared.result().numpy()}")
# Expected output: R-squared: approximately 0.987 (high R² indicating good fit)
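The R² value for the sample data can be checked by hand, and the same arithmetic shows how the total sum of squares behaves when data arrives in batches. This is a plain-Python sketch. The two-batch split is arbitrary and chosen only for illustration.

```python
y_true = [3.0, 5.0, 7.0, 9.0, 11.0]
y_pred = [2.8, 4.7, 7.2, 9.3, 10.5]

# Direct computation over the full dataset
mean_y = sum(y_true) / len(y_true)
ss_total = sum((y - mean_y) ** 2 for y in y_true)                # 40.0
ss_residual = sum((y - p) ** 2 for y, p in zip(y_true, y_pred))  # about 0.51
r_squared = 1.0 - ss_residual / ss_total                         # about 0.987

# Streaming computation: running sums recover the global SS_total across batches
sum_y = sum_y_squared = count = 0.0
for batch in (y_true[:2], y_true[2:]):
    sum_y += sum(batch)
    sum_y_squared += sum(y * y for y in batch)
    count += len(batch)
ss_total_streaming = sum_y_squared - sum_y ** 2 / count          # also 40.0

# Summing per-batch SS_total (each with its own batch mean) gives a different number
ss_total_per_batch = 0.0
for batch in (y_true[:2], y_true[2:]):
    batch_mean = sum(batch) / len(batch)
    ss_total_per_batch += sum((y - batch_mean) ** 2 for y in batch)  # 2.0 + 8.0 = 10.0
```

The gap between 40.0 and 10.0 is the reason an R² metric that is updated batch by batch should accumulate sums that allow the global mean to be reconstructed, rather than centering each batch on its own mean.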
Example 3: Weighted F1 Score for Imbalanced Datasets
For imbalanced datasets, a weighted F1 score might be more appropriate:
import tensorflow as tf

class WeightedF1Score(tf.keras.metrics.Metric):
    def __init__(self, num_classes, name='weighted_f1_score', **kwargs):
        super(WeightedF1Score, self).__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # Initialize true positives, false positives, false negatives for each class
        self.tp = self.add_weight(
            name='tp', shape=(num_classes,), initializer='zeros', dtype=tf.float32)
        self.fp = self.add_weight(
            name='fp', shape=(num_classes,), initializer='zeros', dtype=tf.float32)
        self.fn = self.add_weight(
            name='fn', shape=(num_classes,), initializer='zeros', dtype=tf.float32)
        self.class_counts = self.add_weight(
            name='class_counts', shape=(num_classes,), initializer='zeros', dtype=tf.float32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Convert probabilities to one-hot encoded predictions
        y_pred = tf.one_hot(tf.argmax(y_pred, axis=1), depth=self.num_classes)
        # One-hot encode y_true if it's not already (static rank check works in graph mode too)
        if y_true.shape.rank == 1:
            y_true = tf.one_hot(tf.cast(y_true, tf.int32), depth=self.num_classes)
        y_true = tf.cast(y_true, tf.float32)
        # Calculate class counts for weighting
        class_counts_batch = tf.reduce_sum(y_true, axis=0)
        # Calculate true positives, false positives, false negatives for each class
        tp_batch = tf.reduce_sum(y_true * y_pred, axis=0)
        fp_batch = tf.reduce_sum((1 - y_true) * y_pred, axis=0)
        fn_batch = tf.reduce_sum(y_true * (1 - y_pred), axis=0)
        # Update state variables
        self.tp.assign_add(tp_batch)
        self.fp.assign_add(fp_batch)
        self.fn.assign_add(fn_batch)
        self.class_counts.assign_add(class_counts_batch)

    def result(self):
        # Calculate per-class precision and recall
        precision = self.tp / (self.tp + self.fp + tf.keras.backend.epsilon())
        recall = self.tp / (self.tp + self.fn + tf.keras.backend.epsilon())
        # Calculate per-class F1
        f1 = 2 * precision * recall / (precision + recall + tf.keras.backend.epsilon())
        # Calculate weight for each class (its share of the true samples)
        weights = self.class_counts / tf.reduce_sum(self.class_counts)
        # Return weighted average F1
        return tf.reduce_sum(f1 * weights)

    def reset_state(self):
        for var in self.variables:
            var.assign(tf.zeros_like(var))

# Example usage
num_classes = 3
batch_size = 4

# Generate sample data (predictions are random, so the printed score varies between runs)
y_true = tf.one_hot([0, 1, 2, 1], depth=num_classes)
y_pred = tf.random.normal([batch_size, num_classes])

# Create and use metric
weighted_f1 = WeightedF1Score(num_classes=num_classes)
weighted_f1.update_state(y_true, y_pred)
print(f"Weighted F1 Score: {weighted_f1.result().numpy()}")
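Because the example above uses random predictions, its output changes from run to run. For a reproducible check, the support-weighted F1 can be worked out by hand in plain Python. The predicted class ids below are made up for illustration, and `eps` plays the role of tf.keras.backend.epsilon(), whose default is 1e-7.

```python
y_true = [0, 1, 2, 1]          # true class ids
y_pred = [0, 1, 1, 1]          # made-up predicted class ids
num_classes = 3
eps = 1e-7

weighted_f1 = 0.0
per_class_f1 = []
for c in range(num_classes):
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    per_class_f1.append(f1)
    support = y_true.count(c) / len(y_true)   # class weight = share of true samples
    weighted_f1 += f1 * support

# Per-class F1 is roughly 1.0, 0.8 and 0.0 with weights 0.25, 0.5 and 0.25,
# so weighted_f1 is approximately 0.65
```

Class 2 is never predicted, so it contributes an F1 of zero. With support weighting, that costs only a quarter of the score because class 2 makes up a quarter of the true labels.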
Real-World Application: Custom Metric for Medical Image Segmentation
Here's a practical example for medical image segmentation models where the Dice coefficient is a common evaluation metric:
import tensorflow as tf

class DiceCoefficient(tf.keras.metrics.Metric):
    def __init__(self, name='dice_coefficient', smooth=1.0, **kwargs):
        super(DiceCoefficient, self).__init__(name=name, **kwargs)
        self.smooth = smooth
        self.dice_sum = self.add_weight(name='dice_sum', initializer='zeros')
        self.batch_count = self.add_weight(name='batch_count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Flatten the inputs (cast so integer masks work; sample_weight is ignored)
        y_true_flat = tf.reshape(tf.cast(y_true, tf.float32), [-1])
        y_pred_flat = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
        # Calculate intersection and dice coefficient
        intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
        dice = (2.0 * intersection + self.smooth) / (
            tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + self.smooth
        )
        # Update state variables
        self.dice_sum.assign_add(dice)
        self.batch_count.assign_add(1.0)

    def result(self):
        # Average of the per-batch Dice values
        return self.dice_sum / self.batch_count

    def reset_state(self):
        self.dice_sum.assign(0.0)
        self.batch_count.assign(0.0)

# Example usage in a segmentation model
def create_segmentation_model(input_shape=(128, 128, 1)):
    inputs = tf.keras.Input(shape=input_shape)
    # Simple U-Net-like structure (simplified)
    x = tf.keras.layers.Conv2D(16, 3, activation='relu', padding='same')(inputs)
    x = tf.keras.layers.MaxPooling2D(2)(x)
    x = tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same')(x)
    x = tf.keras.layers.Conv2DTranspose(16, 3, strides=2, activation='relu', padding='same')(x)
    outputs = tf.keras.layers.Conv2D(1, 1, activation='sigmoid')(x)
    model = tf.keras.Model(inputs, outputs)
    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy', DiceCoefficient()]
    )
    return model

# Create model with custom metric
model = create_segmentation_model()
model.summary()  # summary() prints the table itself and returns None
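The role of the smooth term is easiest to see on tiny masks. The following plain-Python sketch applies the same formula to flat lists. The helper name `dice` and the mask values are made up for illustration.

```python
def dice(y_true, y_pred, smooth=1.0):
    """Dice coefficient for flat lists of mask values (plain-Python illustration)."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    return (2.0 * intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)

partial_overlap = dice([1, 1, 0, 0], [1, 0, 0, 0])          # (2*1 + 1) / (2 + 1 + 1) = 0.75
perfect_overlap = dice([1, 1, 0, 0], [1, 1, 0, 0])          # (2*2 + 1) / (2 + 2 + 1) = 1.0
both_empty = dice([0, 0, 0, 0], [0, 0, 0, 0])               # (0 + 1) / (0 + 0 + 1) = 1.0
unsmoothed = dice([1, 1, 0, 0], [1, 0, 0, 0], smooth=0.0)   # 2 / 3
```

Without smoothing, two empty masks would produce 0 / 0. With it, they score 1.0, at the cost of nudging partial overlaps upward (0.75 here instead of about 0.667), an effect that shrinks as masks get larger.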
Best Practices for Custom Metrics
When creating custom metrics, keep these best practices in mind:
- Efficiency: Optimize your metric computation, especially for large datasets.
- Numerical Stability: Always add small epsilon values to denominators to avoid division by zero.
- Batch Independence: Design metrics to handle variable batch sizes correctly.
- Reset State Properly: Ensure all internal state variables are reset between epochs.
- Stateless vs. Stateful: Choose the appropriate approach based on your needs.
- Testing: Verify your custom metric with simple test cases to ensure correctness.
- Documentation: Comment your code thoroughly to explain the metric's purpose and implementation.
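The batch independence and testing points can be combined into one small check: feed the same data as a single batch and as two batches, and compare the results. The helper below is a sketch. `check_batch_invariance` and `ToyAccuracy` are hypothetical names, and the helper assumes only that the metric exposes update_state(y_true, y_pred) and result(), that the inputs can be sliced, and that result() converts to a Python float (for a Keras metric this assumes eager execution).

```python
import math

def check_batch_invariance(make_metric, y_true, y_pred, split_at, tol=1e-6):
    """Return True if one big batch and two smaller batches give the same result."""
    whole = make_metric()
    whole.update_state(y_true, y_pred)

    chunked = make_metric()
    chunked.update_state(y_true[:split_at], y_pred[:split_at])
    chunked.update_state(y_true[split_at:], y_pred[split_at:])

    return math.isclose(float(whole.result()), float(chunked.result()), abs_tol=tol)


class ToyAccuracy:
    """Hypothetical plain-Python metric used only to exercise the helper."""

    def __init__(self):
        self.correct = 0
        self.seen = 0

    def update_state(self, y_true, y_pred):
        self.correct += sum(1 for t, p in zip(y_true, y_pred) if t == p)
        self.seen += len(y_true)

    def result(self):
        return self.correct / self.seen if self.seen else 0.0


# Uneven split (1 sample, then 3) to catch metrics that average per-batch values
is_invariant = check_batch_invariance(ToyAccuracy, [1, 0, 1, 1], [1, 1, 1, 0], split_at=1)
```

An uneven split is deliberate: a metric that averages per-batch values passes with equal halves more easily than with batches of different sizes.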
Summary
Custom metrics in TensorFlow provide flexibility to evaluate your models based on domain-specific needs. We've covered:
- Different approaches to creating custom metrics: functions, subclassing Metric, and using MeanMetricWrapper
- Practical examples for different use cases including classification, regression, and image segmentation
- Best practices for implementing efficient and reliable metrics
By creating custom metrics, you can gain deeper insights into your model's performance and better align your evaluation with your project's specific goals.
Additional Resources
- TensorFlow Metrics API Documentation
- Guide to TensorFlow Metrics
- Research Paper: "Metrics for Medical Image Segmentation"
Exercises
- Create a custom metric that implements Matthews Correlation Coefficient (MCC) for binary classification.
- Modify the R-squared implementation to handle weighted samples.
- Implement a custom metric for multi-class object detection that computes the mean Average Precision (mAP).
- Create a custom metric for time series forecasting that weights recent predictions more heavily than older ones.
- Combine multiple metrics into a single custom metric for a specific domain (e.g., combining accuracy and latency for real-time systems).