TensorFlow Pruning
Introduction
When deploying machine learning models to production, especially on resource-constrained devices like mobile phones or edge devices, model size and inference speed become critical factors. TensorFlow Pruning is a powerful technique that addresses these challenges by systematically removing unnecessary weights from your neural network models.
Pruning works on a simple premise: many weights in neural networks contribute very little to the final output and can be removed without significantly affecting model performance. By "pruning" these weights (setting them to zero), we can create sparse models that take up less storage space once compressed and can run faster on sparsity-aware runtimes, while maintaining similar accuracy levels.
In this tutorial, we'll explore how to implement pruning in TensorFlow using the TensorFlow Model Optimization Toolkit, understand its benefits, and learn best practices for deploying pruned models.
Prerequisites
Before diving into pruning, make sure you have:
- Basic understanding of TensorFlow and Keras
- Python 3.6+
- TensorFlow 2.x installed
- TensorFlow Model Optimization Toolkit installed
If you need to install the TensorFlow Model Optimization Toolkit:
pip install tensorflow-model-optimization
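To verify the installation and check which version you have (the toolkit exposes a standard version attribute):

python -c "import tensorflow_model_optimization as tfmot; print(tfmot.__version__)"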
Understanding Model Pruning
What is Pruning?
Pruning is the process of systematically zeroing out parameters (weights) in a neural network that have minimal impact on the model's predictions. These "pruned" parameters don't need to be stored or used during inference, resulting in:
- Reduced model size: Fewer non-zero parameters means the saved model compresses far better (e.g., with gzip)
- Faster inference: Fewer effective computations during forward passes, on runtimes that can exploit sparsity
- Lower memory footprint: Less memory required to store model weights
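To make the premise concrete, here is a minimal NumPy sketch (independent of the toolkit) of magnitude-based pruning applied to a single weight matrix; the 50% sparsity target is illustrative:

import numpy as np

def magnitude_prune(weights, sparsity):
    """Zero out the smallest-magnitude fraction of weights."""
    flat = np.abs(weights).flatten()
    k = int(sparsity * flat.size)   # number of weights to remove
    threshold = np.sort(flat)[k]    # magnitude below which weights are dropped
    mask = np.abs(weights) >= threshold
    return weights * mask

w = np.random.randn(4, 4)
w_pruned = magnitude_prune(w, sparsity=0.5)
print(f"Zeroed fraction: {np.mean(w_pruned == 0):.2f}")  # ~0.50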
Types of Pruning
TensorFlow supports several pruning approaches:
- Magnitude-based pruning: Removes weights with the smallest absolute values (the default in the toolkit)
- Structured pruning: Removes weights in fixed-size blocks, up to entire channels or filters
- Constant sparsity pruning: Maintains a fixed percentage of zeroed weights throughout training (the ConstantSparsity schedule)
- Scheduled pruning: Gradually increases sparsity during training (the PolynomialDecay schedule)
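The last two correspond to schedule classes in the toolkit; here is a quick sketch of constructing each (the step values are illustrative):

import tensorflow_model_optimization as tfmot

# Hold 50% sparsity from step 0, re-applying the mask every 100 steps
constant = tfmot.sparsity.keras.ConstantSparsity(
    target_sparsity=0.5, begin_step=0, frequency=100)

# Ramp sparsity from 0% to 80% between steps 0 and 2000
polynomial = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.8,
    begin_step=0, end_step=2000)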
Basic Pruning Implementation
Let's implement a simple example of magnitude-based pruning on a small CNN model for MNIST:
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow_model_optimization as tfmot
# Load and preprocess MNIST dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
# Normalize the images
train_images = train_images / 255.0
test_images = test_images / 255.0
# Reshape for the model
train_images = train_images.reshape(-1, 28, 28, 1)
test_images = test_images.reshape(-1, 28, 28, 1)
# Create a simple CNN model
def create_model():
    model = keras.Sequential([
        keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
# Create baseline model
model = create_model()
# Train the model
model.fit(train_images, train_labels, epochs=5, validation_split=0.1)
# Evaluate the baseline model, and save a copy now for the later size
# comparison (pruning below modifies this model's weights in place,
# since the pruning wrapper shares the underlying layer objects)
baseline_loss, baseline_accuracy = model.evaluate(test_images, test_labels, verbose=0)
print(f"Baseline model accuracy: {baseline_accuracy:.4f}")
import tempfile
_, baseline_model_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, baseline_model_file, include_optimizer=False)
Now, let's apply pruning to this model:
# Define pruning configuration
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=0.5,
        begin_step=0,
        end_step=1000)
}
# Apply pruning to the model
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
model_for_pruning = prune_low_magnitude(model, **pruning_params)
# Recompile the model
model_for_pruning.compile(optimizer='adam',
                          loss='sparse_categorical_crossentropy',
                          metrics=['accuracy'])
# Create callback for pruning
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]
# Train the pruned model
model_for_pruning.fit(train_images, train_labels,
                      batch_size=128,
                      epochs=5,
                      validation_split=0.1,
                      callbacks=callbacks)
# Evaluate the pruned model
pruned_loss, pruned_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)
print(f"Pruned model accuracy: {pruned_accuracy:.4f}")
# Strip the pruning wrapper for deployment
final_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
# Compare model sizes. Zeroed weights still take four bytes each in an
# uncompressed .h5 file, so we compare sizes after compression, where
# the pruned model's runs of zeros shrink dramatically.
import os
import zipfile

def get_compressed_size(keras_file):
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(keras_file)
    return os.path.getsize(zipped_file)

# Save the pruned model
_, pruned_model_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(final_model, pruned_model_file, include_optimizer=False)

print(f"Baseline model size: {get_compressed_size(baseline_model_file) / 1024 / 1024:.2f} MB")
print(f"Pruned model size: {get_compressed_size(pruned_model_file) / 1024 / 1024:.2f} MB")
Expected output (exact numbers will vary from run to run):
Baseline model accuracy: 0.9912
Pruned model accuracy: 0.9908
Baseline model size: 1.73 MB
Pruned model size: 0.87 MB
As you can see, we've maintained almost identical accuracy while roughly halving the compressed model size!
Pruning Workflow Explained
Let's break down the key components in the pruning process:
1. Define Pruning Schedule
The PolynomialDecay schedule gradually increases the sparsity (the percentage of weights set to zero) from initial_sparsity to final_sparsity over the course of training:
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,  # Start with no sparsity (all weights used)
    final_sparsity=0.5,    # End with 50% of weights pruned
    begin_step=0,          # When to start pruning
    end_step=1000          # When to stop increasing sparsity
)
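Rather than hard-coding end_step, a common approach is to compute it from the dataset size, batch size, and number of fine-tuning epochs, so the sparsity ramp spans the entire run; a sketch assuming a 10% validation split:

import numpy as np

batch_size = 128
epochs = 5
num_train = int(len(train_images) * 0.9)  # 10% held out for validation
end_step = int(np.ceil(num_train / batch_size)) * epochs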
2. Wrap the Model with Pruning
We apply the prune_low_magnitude function to wrap each prunable layer of our model with pruning functionality:
model_for_pruning = prune_low_magnitude(model, **pruning_params)
3. Add Pruning Callback
The UpdatePruningStep callback is essential: it advances the pruning step counter and applies the pruning schedule during training, and training a pruned model without it raises an error:
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]
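The toolkit also provides a PruningSummaries callback that logs per-layer sparsity and threshold statistics to TensorBoard, which is handy for confirming the schedule behaves as expected:

import tempfile

log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
]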
4. Strip Pruning for Deployment
Before deployment, we need to remove the pruning wrappers while keeping the pruned weights:
final_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
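After stripping, you can confirm the weights really are sparse by counting zeros in each kernel; a small helper along these lines (the function name is just for illustration):

import numpy as np

def print_sparsity(model):
    for layer in model.layers:
        for weights in layer.get_weights():
            if weights.ndim > 1:  # kernels only, skip biases
                print(f"{layer.name}: {np.mean(weights == 0):.1%} zeros")

print_sparsity(final_model)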
Advanced Pruning Techniques
Selective Layer Pruning
You might not want to prune all layers equally. Some layers (like the final classification layer) might be more sensitive to pruning than others:
def apply_selective_pruning(model):
    # Rebuild the model layer by layer, wrapping only the layers we want
    # to prune. Note that this reuses the original layer instances; for
    # production code, tf.keras.models.clone_model with a custom
    # clone_function is more robust.
    pruned_layers = []
    for i, layer in enumerate(model.layers):
        # Don't prune the final dense (classification) layer
        if i == len(model.layers) - 1 and isinstance(layer, keras.layers.Dense):
            pruned_layers.append(layer)
            continue
        # Apply higher sparsity to conv layers, lower to dense layers
        if isinstance(layer, keras.layers.Conv2D):
            sparsity = 0.7  # 70% of weights pruned
        elif isinstance(layer, keras.layers.Dense):
            sparsity = 0.5  # 50% of weights pruned
        else:
            # Layers without weights (pooling, flatten, etc.) pass through
            pruned_layers.append(layer)
            continue
        pruning_params = {
            'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
                sparsity, begin_step=0, frequency=100
            )
        }
        pruned_layers.append(prune_low_magnitude(layer, **pruning_params))
    # Recreate the model with the selectively pruned layers
    return keras.Sequential(pruned_layers)
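Using the helper follows the same pattern as before: build, recompile, and fine-tune with the pruning callback. A sketch reusing the MNIST model and data from the basic example:

selectively_pruned = apply_selective_pruning(create_model())
selectively_pruned.compile(optimizer='adam',
                           loss='sparse_categorical_crossentropy',
                           metrics=['accuracy'])
selectively_pruned.fit(train_images, train_labels,
                       batch_size=128, epochs=3, validation_split=0.1,
                       callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])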
Structured Pruning
While the default pruning is unstructured (individual weights are pruned anywhere in a tensor), structured pruning zeroes weights in fixed-size blocks, up to entire channels or filters, a regular pattern that sparse inference kernels can exploit for greater hardware acceleration:
# Structured pruning example
block_size = (1, 4)  # Prune in blocks of 1x4
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=0.5,
        begin_step=0,
        end_step=1000
    ),
    'block_size': block_size,
    'block_pooling_type': 'AVG'  # Use the block average to decide pruning
}
structured_pruned_model = prune_low_magnitude(model, **pruning_params)
Real-World Application: Pruning for Mobile Deployment
Let's see a complete example of training, pruning, and exporting a model for mobile deployment:
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
# 1. Define a simple model for image classification
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# 2. Load dataset (e.g., CIFAR-10)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0
# 3. Train the original model and record its baseline accuracy now,
# before pruning (the pruning wrapper shares weights with this model,
# so its weights change during the pruned fine-tuning below)
model.fit(train_images, train_labels, epochs=5, validation_split=0.1)
baseline_accuracy = model.evaluate(test_images, test_labels, verbose=0)[1]
# 4. Define pruning configuration; end_step should cover the whole
# fine-tuning run (5 epochs, batch size 128, 10% of data held out)
batch_size = 128
epochs = 5
num_train = int(len(train_images) * 0.9)
end_step = int(np.ceil(num_train / batch_size)) * epochs

pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,
    final_sparsity=0.6,  # 60% of weights will be pruned
    begin_step=0,
    end_step=end_step
)
# 5. Apply pruning
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_schedule=pruning_schedule
)
# Compile the pruned model
pruned_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)
# 6. Fine-tune the pruned model
pruned_model.fit(
    train_images, train_labels,
    batch_size=128, epochs=5, validation_split=0.1,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()]
)
# 7. Evaluate the fine-tuned pruned model against the baseline from step 3
pruned_accuracy = pruned_model.evaluate(test_images, test_labels, verbose=0)[1]
print(f"Baseline model accuracy: {baseline_accuracy:.4f}")
print(f"Pruned model accuracy: {pruned_accuracy:.4f}")
# 8. Prepare for export
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
# 9. Export as TFLite
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
tflite_model = converter.convert()
# 10. Optionally apply post-training quantization on top of pruning
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
compressed_tflite_model = converter.convert()
# Save the models
with open('pruned_model.tflite', 'wb') as f:
    f.write(tflite_model)
with open('pruned_compressed_model.tflite', 'wb') as f:
    f.write(compressed_tflite_model)
# Compare sizes
print(f"Original TFLite size: {len(tflite_model) / 1024:.2f} KB")
print(f"Compressed TFLite size: {len(compressed_tflite_model) / 1024:.2f} KB")
This workflow creates a pruned model ready for deployment on mobile or edge devices via TensorFlow Lite.
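Before shipping, it's worth smoke-testing the exported file with the TFLite interpreter; a minimal check on a single CIFAR-10 test image:

import numpy as np

interpreter = tf.lite.Interpreter(model_path='pruned_model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run one test image through the exported model
sample = test_images[:1].astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], sample)
interpreter.invoke()
probs = interpreter.get_tensor(output_details[0]['index'])
print("Predicted class:", np.argmax(probs))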
Best Practices for TensorFlow Pruning
- Start with a well-trained model: Always prune from a model that's already been trained to good accuracy.
- Use gradual pruning: Gradually increasing sparsity during training generally works better than pruning all at once.
- Fine-tune after pruning: Always fine-tune your model after applying pruning to recover accuracy.
- Test different sparsity levels: Try different final sparsity levels (40%, 50%, 60%, etc.) to find the best trade-off between model size and accuracy (see the sketch after this list).
- Combine with other techniques: Pruning works well in combination with quantization and knowledge distillation for even greater optimization.
- Monitor accuracy closely: Make sure the pruned model's accuracy doesn't drop too much compared to the original model.
- Consider layer sensitivity: Some layers (typically early layers and final classification layers) may be more sensitive to pruning than others.
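Here is one way to run the sparsity sweep mentioned above, reusing the MNIST model and data from the basic example; the epoch counts are kept short purely for illustration:

for final_sparsity in [0.4, 0.5, 0.6, 0.8]:
    base = create_model()
    base.fit(train_images, train_labels, epochs=2, verbose=0)
    schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=final_sparsity,
        begin_step=0, end_step=800)
    candidate = tfmot.sparsity.keras.prune_low_magnitude(
        base, pruning_schedule=schedule)
    candidate.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
    candidate.fit(train_images, train_labels,
                  batch_size=128, epochs=2, verbose=0,
                  callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
    _, acc = candidate.evaluate(test_images, test_labels, verbose=0)
    print(f"final_sparsity={final_sparsity:.0%}: test accuracy {acc:.4f}")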
Debugging and Common Issues
Issue: Significant Accuracy Drop
If you see a large drop in accuracy:
- Lower your final sparsity target
- Increase the fine-tuning epochs
- Use a more gradual pruning schedule
Issue: Model Size Not Decreasing
If your model size isn't decreasing as expected:
- Make sure you called strip_pruning() before saving
- Check whether you're saving the optimizer state, which can be large
- Remember that zeroed weights occupy full space in an uncompressed file; compress the saved model (e.g., zip or gzip) to realize the savings
- Consider using TensorFlow Lite with additional optimizations
- Consider using TensorFlow Lite with additional optimizations
Issue: Slow Inference Despite Pruning
If inference isn't faster despite pruning:
- Standard TensorFlow doesn't automatically accelerate sparse models
- Convert to TensorFlow Lite format, which can take advantage of sparsity
- Use TensorFlow Lite with specialized delegates for hardware acceleration
Summary
TensorFlow Pruning is a powerful technique for optimizing models for deployment, particularly in resource-constrained environments. By systematically removing unnecessary weights, pruning can significantly reduce model size and potentially speed up inference while maintaining similar accuracy levels.
In this tutorial, we've learned:
- The basic concepts of model pruning
- How to implement pruning using TensorFlow Model Optimization Toolkit
- Advanced pruning techniques like selective and structured pruning
- A complete workflow for pruning models for mobile deployment
- Best practices and debugging tips for effective pruning
With these techniques, you can make your TensorFlow models more efficient and ready for deployment on a wide range of devices.
Additional Resources
- TensorFlow Model Optimization Toolkit Documentation
- Pruning Comprehensive Guide
- TensorFlow Lite for Mobile Deployment
Exercises
- Experiment with Different Sparsities: Try pruning a model with different sparsity levels (30%, 50%, 70%) and compare the accuracy and size trade-offs.
- Combined Optimization: Apply both pruning and quantization to a model and measure the combined effect on model size and accuracy.
- Custom Pruning Schedule: Create a custom pruning schedule that increases sparsity in steps rather than continuously, and compare its performance to the standard PolynomialDecay.
- Layer-wise Analysis: Analyze which layers in your model are most/least sensitive to pruning by applying different sparsity levels to different layers.
- Real-world Application: Take a pre-trained model for a practical task (like image recognition or natural language processing) and optimize it with pruning for deployment on a mobile device.