
TensorFlow Pruning

Introduction

When deploying machine learning models to production, especially on resource-constrained devices like mobile phones or edge devices, model size and inference speed become critical factors. TensorFlow Pruning is a powerful technique that addresses these challenges by systematically removing unnecessary weights from your neural network models.

Pruning works on a simple premise: many weights in a neural network contribute very little to the final output and can be removed without significantly affecting model performance. By "pruning" these weights (setting them to zero), we create sparse models that compress to a much smaller size and, on runtimes or hardware that can exploit sparsity, run faster, while maintaining similar accuracy.

In this tutorial, we'll explore how to implement pruning in TensorFlow using the TensorFlow Model Optimization Toolkit, understand its benefits, and learn best practices for deploying pruned models.

Prerequisites

Before diving into pruning, make sure you have:

  • Basic understanding of TensorFlow and Keras
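  • Python 3 (a version supported by your TensorFlow release)
  • TensorFlow 2.x installed (the toolkit targets Keras 2; with TensorFlow 2.16 or later, also install tf_keras and set TF_USE_LEGACY_KERAS=1)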
  • TensorFlow Model Optimization Toolkit installed

If you need to install the TensorFlow Model Optimization Toolkit:

bash
pip install tensorflow-model-optimization

Understanding Model Pruning

What is Pruning?

Pruning is the process of systematically zeroing out parameters (weights) in a neural network that have minimal impact on the model's predictions. Because the pruned parameters are exactly zero, they compress very well and can be skipped by sparse-aware runtimes, resulting in:

  1. Reduced model size: zeros compress well, so the model shrinks after standard compression (such as gzip), even though the uncompressed file still stores every value
  2. Faster inference: fewer computations during forward passes, provided the runtime or hardware can exploit sparsity
  3. Lower memory footprint: less memory for weights when they are kept in a compressed or sparse format
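The size benefit in item 1 comes from compression: a pruned tensor still holds the same number of float32 values, but its zeros compress very well. A minimal standard-library sketch, using synthetic Gaussian values as stand-in weights (an assumption, not a real model) and zlib as the compressor:

```python
import random
import struct
import zlib


def raw_and_compressed_size(values):
    """Pack floats as little-endian float32 (like a dense weight tensor)
    and return (raw byte count, zlib-compressed byte count)."""
    raw = struct.pack(f"<{len(values)}f", *values)
    return len(raw), len(zlib.compress(raw, 9))


random.seed(0)
n = 100_000
# Synthetic stand-in for a trained weight tensor (assumption).
dense = [random.gauss(0.0, 0.05) for _ in range(n)]

# Zero the 50% of entries with the smallest magnitudes.
threshold = sorted(abs(v) for v in dense)[n // 2]
sparse = [v if abs(v) >= threshold else 0.0 for v in dense]

dense_raw, dense_zip = raw_and_compressed_size(dense)
sparse_raw, sparse_zip = raw_and_compressed_size(sparse)
print(f"dense : raw {dense_raw} bytes, compressed {dense_zip} bytes")
print(f"sparse: raw {sparse_raw} bytes, compressed {sparse_zip} bytes")
```

Both arrays serialize to the same raw size; only the compressed size of the pruned one drops.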

Types of Pruning

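The TensorFlow Model Optimization Toolkit combines a pruning criterion with a pruning schedule:

  1. Magnitude-based pruning: Removes the weights with the smallest absolute values (the criterion behind prune_low_magnitude)
  2. Structured pruning: Removes weights in regular patterns, such as blocks of neighboring weights or M-by-N patterns like 2:4 sparsity, which some hardware can accelerate
  3. Constant sparsity pruning (ConstantSparsity schedule): Maintains a fixed percentage of zeroed weights
  4. Scheduled pruning (PolynomialDecay schedule): Gradually increases sparsity during training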
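The magnitude criterion can be illustrated on a single weight matrix with NumPy. This is a simplified sketch: the helper name magnitude_prune is hypothetical, and the toolkit's own implementation differs in details (it recomputes masks periodically during training according to a schedule).

```python
import numpy as np


def magnitude_prune(weights: np.ndarray, sparsity: float):
    """Return (pruned_weights, mask) with `sparsity` of the entries zeroed.

    Entries whose magnitude is at or below the k-th smallest |w| are masked
    out, where k = floor(sparsity * weights.size).
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    k = int(np.floor(sparsity * weights.size))
    if k == 0:
        mask = np.ones_like(weights, dtype=bool)
    else:
        # The k-th smallest magnitude acts as the pruning threshold.
        threshold = np.partition(np.abs(weights).ravel(), k - 1)[k - 1]
        mask = np.abs(weights) > threshold
    return weights * mask, mask


rng = np.random.default_rng(0)
w = rng.normal(size=(64, 32))
pruned, mask = magnitude_prune(w, 0.5)
print(f"zeroed {np.count_nonzero(~mask)} of {w.size} weights")
```

Every surviving weight is larger in magnitude than every pruned one, which is exactly the "smallest absolute values" rule.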

Basic Pruning Implementation

Let's implement a simple example of magnitude-based pruning on a small CNN model for MNIST:

python
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow_model_optimization as tfmot

# Load and preprocess MNIST dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# Normalize the images
train_images = train_images / 255.0
test_images = test_images / 255.0

# Reshape for the model
train_images = train_images.reshape(-1, 28, 28, 1)
test_images = test_images.reshape(-1, 28, 28, 1)

# Create a simple CNN model
def create_model():
    model = keras.Sequential([
        keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Create baseline model
model = create_model()

# Train the model
model.fit(train_images, train_labels, epochs=5, validation_split=0.1)

# Evaluate the baseline model
baseline_loss, baseline_accuracy = model.evaluate(test_images, test_labels, verbose=0)
print(f"Baseline model accuracy: {baseline_accuracy:.4f}")

Now, let's apply pruning to this model:

python
# Save the baseline first: prune_low_magnitude wraps the original layer
# objects, so fine-tuning the pruned model also changes `model`'s weights
import os
import tempfile
import zipfile

_, baseline_model_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, baseline_model_file, include_optimizer=False)

# Define pruning configuration: ramp sparsity from 0% to 50% over 1000 steps
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=0.5,
        begin_step=0,
        end_step=1000)
}

# Apply pruning to the model
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
model_for_pruning = prune_low_magnitude(model, **pruning_params)

# Recompile the model
model_for_pruning.compile(optimizer='adam',
                          loss='sparse_categorical_crossentropy',
                          metrics=['accuracy'])

# Create callback for pruning (required during training)
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]

# Fine-tune the pruned model
model_for_pruning.fit(train_images, train_labels,
                      batch_size=128,
                      epochs=5,
                      validation_split=0.1,
                      callbacks=callbacks)

# Evaluate the pruned model
pruned_loss, pruned_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)
print(f"Pruned model accuracy: {pruned_accuracy:.4f}")

# Strip the pruning wrappers for deployment
final_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

# Save the pruned model
_, pruned_model_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(final_model, pruned_model_file, include_optimizer=False)

# .h5 files store zeros densely, so compare compressed (zipped) sizes
def get_compressed_model_size(file):
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)
    return os.path.getsize(zipped_file)

print(f"Compressed baseline model size: {get_compressed_model_size(baseline_model_file) / 1024:.2f} KB")
print(f"Compressed pruned model size: {get_compressed_model_size(pruned_model_file) / 1024:.2f} KB")

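Expected output (exact values vary from run to run):

Baseline model accuracy: 0.9912
Pruned model accuracy: 0.9908

Accuracy is almost unchanged, and the compressed pruned model is noticeably smaller than the compressed baseline. The uncompressed .h5 files, however, are about the same size: pruning sets weights to zero but still stores them in dense tensors, so the savings only appear after compression.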

Pruning Workflow Explained

Let's break down the key components in the pruning process:

1. Define Pruning Schedule

The PolynomialDecay schedule gradually increases the sparsity (percentage of weights set to zero) from initial_sparsity to final_sparsity between begin_step and end_step, where a step is one training batch:

python
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,  # Start with no sparsity (all weights used)
    final_sparsity=0.5,    # End with 50% of weights pruned
    begin_step=0,          # When to start pruning
    end_step=1000          # When to stop increasing sparsity
)
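To make the schedule concrete, here is a plain-Python restatement of the polynomial-decay formula, sparsity(t) = final + (initial - final) * (1 - p)^power, where p is the fraction of the window [begin_step, end_step] that has elapsed, clipped to [0, 1]. It assumes what we understand to be the toolkit's defaults, power=3 and frequency=100 (masks are recomputed only every `frequency` steps inside the window). Treat it as an illustration, not the library's code.

```python
def polynomial_decay_sparsity(step, initial_sparsity=0.0, final_sparsity=0.5,
                              begin_step=0, end_step=1000, power=3):
    """Target sparsity at `step` under a polynomial-decay schedule."""
    # Fraction of the pruning window that has elapsed, clipped to [0, 1].
    p = (step - begin_step) / (end_step - begin_step)
    p = min(1.0, max(0.0, p))
    return (initial_sparsity - final_sparsity) * (1 - p) ** power + final_sparsity


def mask_updated_at(step, begin_step=0, end_step=1000, frequency=100):
    """Whether masks are recomputed at `step`: every `frequency` steps in the window."""
    in_window = begin_step <= step <= end_step
    return in_window and (step - begin_step) % frequency == 0


for step in (0, 250, 500, 750, 1000, 1500):
    print(step, round(polynomial_decay_sparsity(step), 4), mask_updated_at(step))
```

With power=3, sparsity climbs quickly at first (about 29% already at step 250 for a 50% target) and levels off as it approaches end_step.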

2. Wrap the Model with Pruning

We apply the prune_low_magnitude function to wrap each layer of our model with pruning functionality:

python
model_for_pruning = prune_low_magnitude(model, **pruning_params)

3. Add Pruning Callback

The UpdatePruningStep callback is essential: it advances the pruning step counter once per training batch and applies the pruning schedule during training:

python
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]
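Because the step counter advances once per batch, begin_step and end_step are measured in batches, not epochs. A small helper sketch for sizing end_step (training_steps is a hypothetical name); it assumes Keras keeps the first floor(n * (1 - validation_split)) examples for training:

```python
import math


def training_steps(num_examples, batch_size, epochs, validation_split=0.0):
    """Number of batches `fit` runs, i.e. how far the pruning step advances."""
    # Assumed to match Keras: the trailing fraction is held out for validation.
    num_train = math.floor(num_examples * (1.0 - validation_split))
    return math.ceil(num_train / batch_size) * epochs


print(training_steps(60_000, batch_size=128, epochs=5, validation_split=0.1))
```

For the MNIST fine-tuning run above (60,000 images, batch size 128, 5 epochs, 10% validation) this gives 2,110 steps, so end_step=1000 reaches the 50% target a little before the halfway point. For the 50,000 CIFAR-10 training images used later, the same settings give 1,760 steps.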

4. Strip Pruning for Deployment

Before deployment, we need to remove the pruning wrappers while keeping the pruned weights:

python
final_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
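After stripping, it is worth confirming that the kernels really contain the expected fraction of zeros. A small NumPy helper sketch (sparsity_report is a hypothetical name) that works on any (name, array) pairs:

```python
import numpy as np


def sparsity_report(named_weights):
    """Print and return the fraction of exactly-zero entries per weight array."""
    report = {}
    for name, w in named_weights:
        w = np.asarray(w)
        report[name] = float(np.mean(w == 0))
        print(f"{name}: {report[name]:.2%} zeros ({w.size} values)")
    return report


sparsity_report([("toy_kernel", np.array([[0.0, 1.2], [0.0, -0.4]]))])
```

For example, assuming Keras 2 (tf_keras) variable names, `sparsity_report((w.name, w.numpy()) for w in final_model.weights if "kernel" in w.name)` should show roughly the target sparsity for each pruned kernel. Biases are not pruned for Dense and Conv2D layers, which is why they are filtered out.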

Advanced Pruning Techniques

Selective Layer Pruning

You might not want to prune all layers equally. Some layers (like the final classification layer) might be more sensitive to pruning than others:

python
def apply_selective_pruning(model):
    # List to hold all (possibly wrapped) layers
    pruned_layers = []

    # Go through each layer
    for i, layer in enumerate(model.layers):
        # Don't prune the final dense layer
        if i == len(model.layers) - 1 and isinstance(layer, keras.layers.Dense):
            pruned_layers.append(layer)
        else:
            # Apply higher sparsity to conv layers, lower to dense layers
            if isinstance(layer, keras.layers.Conv2D):
                sparsity = 0.7  # 70% of weights pruned
            elif isinstance(layer, keras.layers.Dense):
                sparsity = 0.5  # 50% of weights pruned
            else:
                # For layers without weights (pooling, etc.), just add them as-is
                pruned_layers.append(layer)
                continue

            pruning_params = {
                'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
                    sparsity, begin_step=0, frequency=100
                )
            }
            pruned_layers.append(
                prune_low_magnitude(layer, **pruning_params)
            )

    # Recreate the model with pruned layers; compile it and train with
    # UpdatePruningStep before use
    pruned_model = keras.Sequential(pruned_layers)
    return pruned_model
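Because each layer gets its own ConstantSparsity target, the model's overall sparsity is a parameter-weighted average of the per-layer targets. A worked example for the MNIST CNN from create_model() with the targets used in apply_selective_pruning, counting only kernels (biases are not pruned); the dictionary keys are just labels:

```python
# (kernel parameter count, target sparsity) per layer of the MNIST CNN.
kernels = {
    "conv_1": (3 * 3 * 1 * 32, 0.7),
    "conv_2": (3 * 3 * 32 * 64, 0.7),
    "dense_1": (5 * 5 * 64 * 128, 0.5),  # Flatten yields 5*5*64 = 1600 features
    "dense_out": (128 * 10, 0.0),        # final layer left unpruned
}

total = sum(count for count, _ in kernels.values())
zeros = sum(count * sparsity for count, sparsity in kernels.values())
print(f"kernel weights: {total}, zeroed: {zeros:.0f}, overall sparsity: {zeros / total:.2%}")
```

The first Dense layer holds about 91% of the kernel weights, so it dominates: despite the 70% target on the convolutions, overall sparsity comes out at about 51.4%.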

Structured Pruning

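While the default pruning is unstructured (individual weights are pruned anywhere in a tensor), block sparsity zeros contiguous blocks of weights, a more regular pattern that hardware can exploit more easily. Note that prune_low_magnitude does not remove entire channels or filters, and the block pattern applies to 2-D weight matrices such as Dense kernels: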

python
# Structured pruning example
block_size = (1, 4) # Prune in blocks of 1x4

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=0.5,
        begin_step=0,
        end_step=1000
    ),
    'block_size': block_size,
    'block_pooling_type': 'AVG'  # Use average across the block to decide pruning
}

structured_pruned_model = prune_low_magnitude(model, **pruning_params)
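The effect of block_size=(1, 4) with block_pooling_type='AVG' can be pictured with a NumPy sketch: split a 2-D kernel into 1x4 blocks, score each block by its mean absolute value, and zero the lowest-scoring blocks. This is a conceptual illustration, not the toolkit's implementation (which recomputes the mask on a schedule during training); block_prune is a hypothetical helper that assumes the matrix dimensions divide evenly by the block size.

```python
import numpy as np


def block_prune(weights, sparsity, block_size=(1, 4)):
    """Zero whole blocks of a 2-D matrix, ranked by mean |w| per block."""
    rows, cols = weights.shape
    bh, bw = block_size
    if rows % bh or cols % bw:
        raise ValueError("weight shape must be divisible by block_size")
    # View as (rows/bh, bh, cols/bw, bw) and average |w| inside each block.
    blocks = np.abs(weights).reshape(rows // bh, bh, cols // bw, bw)
    block_scores = blocks.mean(axis=(1, 3))
    k = int(sparsity * block_scores.size)  # number of blocks to drop
    keep = np.ones(block_scores.size, dtype=bool)
    keep[np.argsort(block_scores, axis=None)[:k]] = False
    block_mask = keep.reshape(block_scores.shape)
    # Expand the per-block decision back to one entry per weight.
    mask = np.repeat(np.repeat(block_mask, bh, axis=0), bw, axis=1)
    return weights * mask, mask


rng = np.random.default_rng(0)
w = rng.normal(size=(8, 16))
pruned, mask = block_prune(w, 0.5)
print(mask.astype(int))
```

Every 1x4 group in the printed mask is either all ones or all zeros, which is what makes the pattern easier for hardware to exploit than scattered zeros.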

Real-World Application: Pruning for Mobile Deployment

Let's see a complete example of training, pruning, and exporting a model for mobile deployment:

python
import math

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# 1. Define a simple model for image classification
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 2. Load dataset (e.g., CIFAR-10)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0

# 3. Train the original model
model.fit(train_images, train_labels, epochs=5, validation_split=0.1)

# Evaluate the baseline now: the pruning wrappers reuse these layers,
# so fine-tuning below also changes `model`'s weights
baseline_accuracy = model.evaluate(test_images, test_labels, verbose=0)[1]

# 4. Define pruning configuration (steps are batches, excluding validation data)
batch_size = 128
epochs = 5
validation_split = 0.1
num_train = int(len(train_images) * (1 - validation_split))
end_step = math.ceil(num_train / batch_size) * epochs

pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,
    final_sparsity=0.6,  # 60% of weights will be pruned
    begin_step=0,
    end_step=end_step
)

# 5. Apply pruning
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_schedule=pruning_schedule
)

# Compile the pruned model
pruned_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 6. Fine-tune the pruned model
pruned_model.fit(
    train_images, train_labels,
    batch_size=batch_size, epochs=epochs, validation_split=validation_split,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()]
)

# 7. Evaluate the pruned model
pruned_accuracy = pruned_model.evaluate(test_images, test_labels, verbose=0)[1]

print(f"Baseline model accuracy: {baseline_accuracy:.4f}")
print(f"Pruned model accuracy: {pruned_accuracy:.4f}")

# 8. Prepare for export
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

# 9. Export as TFLite (float32 weights; zeros are still stored densely)
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
tflite_model = converter.convert()

# 10. Apply additional compression (optional): dynamic-range quantization
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
compressed_tflite_model = converter.convert()

# Save the models
with open('pruned_model.tflite', 'wb') as f:
    f.write(tflite_model)

with open('pruned_compressed_model.tflite', 'wb') as f:
    f.write(compressed_tflite_model)

# Compare sizes
print(f"Pruned float TFLite size: {len(tflite_model) / 1024:.2f} KB")
print(f"Pruned + quantized TFLite size: {len(compressed_tflite_model) / 1024:.2f} KB")

This workflow creates a pruned model ready for deployment on mobile or edge devices via TensorFlow Lite.

Best Practices for TensorFlow Pruning

  1. Start with a well-trained model: Always prune from a model that's already been trained to good accuracy.

  2. Use gradual pruning: Gradually increasing sparsity during training generally works better than pruning all at once.

  3. Fine-tune after pruning: Always fine-tune your model after applying pruning to recover accuracy.

  4. Test different sparsity levels: Try different final sparsity levels (40%, 50%, 60%, etc.) to find the best trade-off between model size and accuracy.

  5. Combine with other techniques: Pruning works well in combination with quantization and knowledge distillation for even greater optimization.

  6. Monitor accuracy closely: Make sure the pruned model's accuracy doesn't drop too much compared to the original model.

  7. Consider layer sensitivity: Some layers (typically early layers and final classification layers) may be more sensitive to pruning than others.
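Practices 4 and 6 can be combined into a small sweep: try several final sparsities and keep the highest one whose accuracy stays within a chosen budget of the baseline. A sketch in which `evaluate` is a hypothetical callable you supply (for example, one that reloads the saved baseline, prunes it with PolynomialDecay(final_sparsity=s), fine-tunes with UpdatePruningStep, and returns test accuracy) and max_drop=0.01 is an arbitrary example budget:

```python
from typing import Callable, Dict, Iterable, Optional, Tuple


def pick_sparsity(evaluate: Callable[[float], float],
                  baseline_accuracy: float,
                  sparsities: Iterable[float],
                  max_drop: float = 0.01) -> Tuple[Optional[float], Dict[float, float]]:
    """Return (best_sparsity, results): the highest sparsity whose accuracy is
    within `max_drop` of the baseline (None if none qualifies)."""
    results: Dict[float, float] = {}
    best = None
    for s in sorted(sparsities):
        acc = evaluate(s)  # e.g. prune to `s`, fine-tune, evaluate on test data
        results[s] = acc
        if baseline_accuracy - acc <= max_drop:
            best = s
    return best, results


# Made-up accuracies standing in for real training runs.
fake_accuracy = {0.3: 0.99, 0.5: 0.985, 0.7: 0.95}.get
best, results = pick_sparsity(fake_accuracy, 0.991, [0.7, 0.3, 0.5])
print(best, results)
```

With these stand-in numbers the function picks 0.5, because the 70% run loses more than the one-point budget. Reloading the baseline for each trial keeps the runs independent.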

Debugging and Common Issues

Issue: Significant Accuracy Drop

If you see a large drop in accuracy:

  • Lower your final sparsity target
  • Increase the fine-tuning epochs
  • Use a more gradual pruning schedule

Issue: Model Size Not Decreasing

If your model size isn't decreasing as expected:

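  • Make sure you called strip_pruning() before saving
  • Check whether you're saving the optimizer state, which can be large (save with include_optimizer=False)
  • Compare compressed sizes: pruned weights are still stored as dense tensors, so the uncompressed file stays about the same size
  • Consider using TensorFlow Lite with additional optimizations such as quantization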

Issue: Slow Inference Despite Pruning

If inference isn't faster despite pruning:

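  • Standard TensorFlow doesn't automatically accelerate sparse models; pruned weights are still multiplied as dense tensors
  • Convert to TensorFlow Lite with the tf.lite.Optimize.EXPERIMENTAL_SPARSITY optimization, which lets supported runtimes use sparse kernels for some operations
  • Use TensorFlow Lite with specialized delegates (such as XNNPACK) for hardware acceleration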

Summary

TensorFlow Pruning is a powerful technique for optimizing models for deployment, particularly in resource-constrained environments. By systematically removing unnecessary weights, pruning can significantly reduce compressed model size and, with suitable runtime support, speed up inference while maintaining similar accuracy levels.

In this tutorial, we've learned:

  • The basic concepts of model pruning
  • How to implement pruning using TensorFlow Model Optimization Toolkit
  • Advanced pruning techniques like selective and structured pruning
  • A complete workflow for pruning models for mobile deployment
  • Best practices and debugging tips for effective pruning

With these techniques, you can make your TensorFlow models more efficient and ready for deployment on a wide range of devices.


Exercises

  1. Experiment with Different Sparsities: Try pruning a model with different sparsity levels (30%, 50%, 70%) and compare the accuracy and size trade-offs.

  2. Combined Optimization: Apply both pruning and quantization to a model and measure the combined effect on model size and accuracy.

  3. Custom Pruning Schedule: Create a custom pruning schedule that increases sparsity in steps rather than continuously, and compare its performance to the standard PolynomialDecay.

  4. Layer-wise Analysis: Analyze which layers in your model are most/least sensitive to pruning by applying different sparsity levels to different layers.

  5. Real-world Application: Take a pre-trained model for a practical task (like image recognition or natural language processing) and optimize it with pruning for deployment on a mobile device.


