PyTorch Pruning
Neural network pruning is an essential technique for model optimization that helps reduce model size, improve inference speed, and minimize memory footprint. In this guide, we'll explore how to effectively implement pruning techniques in PyTorch.
What is Network Pruning?
Network pruning is the process of systematically removing weights (connections) or entire neurons/filters from a neural network to create a more compact model. Think of it as trimming away the "fat" from an oversized neural network while preserving its core functionality.
The basic premise is simple: most neural networks are overparameterized, meaning they contain more parameters than necessary to solve their target task. Pruning identifies and removes the least important parameters, resulting in:
- Smaller model size
- Faster inference time
- Lower memory requirements
- Reduced energy consumption
Types of Pruning in PyTorch
PyTorch supports several types of pruning techniques:
- Unstructured pruning: Removes individual weights regardless of their position
- Structured pruning: Removes entire structures like neurons or channels
- Local pruning: Applied to specific layers or modules
- Global pruning: Applied across the entire model
Getting Started with Pruning in PyTorch
PyTorch provides pruning capabilities through the torch.nn.utils.prune
module. Let's start with a simple example of unstructured pruning:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
# Create a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Instantiate the model
model = SimpleModel()
# Print parameter statistics before pruning
print("Before pruning:")
print(f"Parameters in fc1: {model.fc1.weight.nelement()}")
# Apply 30% unstructured L1 pruning to the first layer
prune.l1_unstructured(model.fc1, name='weight', amount=0.3)
# Check the effect of pruning
print("\nAfter pruning:")
print(f"Parameters in fc1: {model.fc1.weight.nelement()}")
print(f"Non-zero parameters in fc1: {torch.sum(model.fc1.weight != 0).item()}")
print(f"Sparsity in fc1: {100 * (1 - torch.sum(model.fc1.weight != 0).item() / model.fc1.weight.nelement()):.2f}%")
Output:
Before pruning:
Parameters in fc1: 100352
After pruning:
Parameters in fc1: 100352
Non-zero parameters in fc1: 70246
Sparsity in fc1: 30.00%
This example shows how to apply L1 unstructured pruning to remove 30% of the weights in the first fully-connected layer, resulting in a model with the same dimensions but increased sparsity.
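Under the hood, pruning doesn't delete anything: PyTorch reparameterizes the layer so that the original values live on as a parameter called weight_orig, the binary mask is stored in a buffer called weight_mask, and weight becomes an ordinary attribute recomputed as weight_orig * weight_mask by a forward pre-hook. You can inspect this directly on the model pruned above:
print([name for name, _ in model.fc1.named_parameters()])  # includes 'weight_orig'
print([name for name, _ in model.fc1.named_buffers()])     # includes 'weight_mask'
# The effective weight is the element-wise product of the two
print(torch.equal(model.fc1.weight, model.fc1.weight_orig * model.fc1.weight_mask))  # True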
Pruning Techniques in PyTorch
1. Unstructured Pruning
Unstructured pruning removes individual weights based on their importance. PyTorch offers several methods:
# L1 pruning (remove the weights with the smallest absolute values)
prune.l1_unstructured(module, name, amount)
# Random pruning (remove weights at random)
prune.random_unstructured(module, name, amount)
Note that the prune module has no built-in L2 unstructured method: for individual weights, ranking by squared value selects exactly the same weights as ranking by absolute value, so l1_unstructured already covers it. Ln norms come into play for structured pruning, covered next. If you need a genuinely different criterion, you can define a custom method, as sketched below.
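A minimal sketch of such a custom method, following the subclassing pattern from the official PyTorch pruning tutorial (the ThresholdPruning name and the 0.01 cutoff are illustrative, not a library API):
class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = 'unstructured'

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Keep zeros from any previous pruning and additionally
        # drop weights whose magnitude falls below the threshold
        return default_mask * (t.abs() > self.threshold)

# apply() forwards extra arguments to __init__
ThresholdPruning.apply(model.fc1, 'weight', threshold=0.01)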
2. Structured Pruning
Structured pruning removes entire groups of parameters:
# Remove channels/neurons
prune.ln_structured(model.conv1, name="weight", amount=0.2, n=2, dim=0) # prune output channels
prune.ln_structured(model.conv1, name="weight", amount=0.2, n=2, dim=1) # prune input channels
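To confirm that whole structures were removed (this assumes model.conv1 is an nn.Conv2d, whose weight has shape (out_channels, in_channels, kH, kW)), you can check which output-channel slices became entirely zero after the dim=0 call:
# A pruned output channel is an all-zero slice along dim 0
channel_is_zero = (model.conv1.weight.abs().sum(dim=(1, 2, 3)) == 0)
print(f"Pruned output channels: {int(channel_is_zero.sum())} of {channel_is_zero.numel()}")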
3. Global Pruning
To apply pruning across the entire model:
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,
)
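Because the 30% is measured across all the listed tensors together, per-layer sparsity will generally differ: layers whose weights are smaller in magnitude overall lose proportionally more. You can check the resulting distribution per layer:
for module, pname in parameters_to_prune:
    weight = getattr(module, pname)
    sparsity = 100.0 * float(torch.sum(weight == 0)) / weight.nelement()
    print(f"{type(module).__name__}: {sparsity:.2f}% sparse")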
Iterative Pruning: A Practical Approach
In practice, networks are often pruned iteratively - train, prune, retrain, repeat. Let's implement a simple iterative pruning workflow:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
# Define a simple CNN
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # 9216 = 64 * 12 * 12, assuming 28x28 inputs (e.g., MNIST)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Instantiate model, loss, optimizer
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Define layers to be pruned
modules_to_prune = [(model.conv1, 'weight'),
                    (model.conv2, 'weight'),
                    (model.fc1, 'weight'),
                    (model.fc2, 'weight')]

# Dummy train function (replace with your actual training loop)
def train_one_epoch(model, optimizer, criterion, train_loader):
    # Training code would go here
    pass

# Dummy evaluate function (replace with your actual evaluation)
def evaluate(model, test_loader):
    # Evaluation code would go here
    return 0.95  # Dummy accuracy
# Function to compute model sparsity.
# Note: once pruning is applied, the values move to 'weight_orig' in
# named_parameters(), and that tensor is never zeroed; the masked weights
# live in each module's 'weight' attribute, so we read that instead.
def compute_sparsity(model):
    total_params = 0
    nonzero_params = 0
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            weight = module.weight
            total_params += weight.numel()
            nonzero_params += torch.sum(weight != 0).item()
    return 100 * (1 - nonzero_params / total_params)
# Iterative pruning with fine-tuning
prune_iterations = 5
amount_per_iteration = 0.2 # 20% of remaining weights each time
train_loader = None # Replace with your data loader
test_loader = None # Replace with your test data loader
print("Initial model accuracy and sparsity:")
accuracy = evaluate(model, test_loader)
sparsity = compute_sparsity(model)
print(f"Accuracy: {accuracy:.4f}, Sparsity: {sparsity:.2f}%")
for i in range(prune_iterations):
    print(f"\nPruning iteration {i+1}/{prune_iterations}")

    # Apply pruning
    prune.global_unstructured(
        modules_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount_per_iteration,
    )

    # Fine-tune the pruned model
    for epoch in range(3):  # A few epochs of fine-tuning
        train_one_epoch(model, optimizer, criterion, train_loader)

    # Evaluate the pruned model
    accuracy = evaluate(model, test_loader)
    sparsity = compute_sparsity(model)
    print(f"Accuracy: {accuracy:.4f}, Sparsity: {sparsity:.2f}%")
This example demonstrates iterative pruning, where we:
- Train a baseline model
- Repeatedly prune and fine-tune
- Monitor accuracy and sparsity trade-offs
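The train_one_epoch and evaluate functions above are deliberate stubs. For reference, a minimal sketch of what they might look like with standard classification DataLoaders (e.g., torchvision's MNIST):
def train_one_epoch(model, optimizer, criterion, train_loader):
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

def evaluate(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total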
Permanent Pruning
By default, PyTorch maintains pruning masks separately from the weights. To make pruning permanent:
# Apply pruning
prune.l1_unstructured(model.fc1, name='weight', amount=0.3)
# Make it permanent
prune.remove(model.fc1, 'weight')
Making pruning permanent is essential before exporting the model for deployment. Keep in mind, though, that prune.remove only folds the mask into the weight tensor: the zeros are still stored densely, so file size and memory use don't shrink on their own unless you also move to a sparse representation or a runtime that exploits sparsity.
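You can see what remove does by comparing a layer's buffers before and after; here on fc2 of the earlier SimpleModel, which this section hasn't pruned yet:
prune.l1_unstructured(model.fc2, name='weight', amount=0.3)
print('weight_mask' in dict(model.fc2.named_buffers()))  # True: mask is active

prune.remove(model.fc2, 'weight')
print('weight_mask' in dict(model.fc2.named_buffers()))  # False: mask folded into 'weight'
print(f"Zeros baked into fc2.weight: {int(torch.sum(model.fc2.weight == 0))}")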
Real-World Example: MobileNet Pruning
Let's look at a real-world example of pruning a MobileNet model for improved deployment efficiency:
import torch
import torchvision.models as models
import torch.nn.utils.prune as prune
# Load a pre-trained MobileNetV2
model = models.mobilenet_v2(weights='DEFAULT')
# Identify convolution layers for pruning
conv_modules = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        conv_modules.append((module, 'weight'))

# Apply global pruning to all convolution layers
prune.global_unstructured(
    conv_modules,
    pruning_method=prune.L1Unstructured,
    amount=0.2,  # Prune 20% of weights
)
# Check overall sparsity across the pruned convolution weights.
# As with compute_sparsity() earlier, we read each module's 'weight'
# attribute (the masked tensor) rather than the unmasked 'weight_orig'
# parameter that named_parameters() would return after pruning.
total_params = 0
zero_params = 0
for module, pname in conv_modules:
    weight = getattr(module, pname)
    total_params += weight.numel()
    zero_params += torch.sum(weight == 0).item()
print(f"Model sparsity: {100 * zero_params / total_params:.2f}%")
# Test inference speed (optional)
import time

model.eval()  # inference mode for dropout/batch norm
sample_input = torch.randn(1, 3, 224, 224)  # Sample input image

with torch.no_grad():
    _ = model(sample_input)  # Warm-up run so setup cost doesn't skew timing
    start_time = time.time()
    for _ in range(100):  # Run 100 inferences for reliable timing
        _ = model(sample_input)
    end_time = time.time()

print(f"Average inference time: {(end_time - start_time) / 100 * 1000:.2f} ms")
Considerations and Best Practices
- Start with a well-trained model: Always prune from a fully trained model for best results.
- Iterative pruning works better: Gradually prune and fine-tune rather than removing a large percentage all at once.
- Layer sensitivity: Different layers have different sensitivity to pruning. Generally:
  - Early layers are more sensitive than later layers
  - Convolutional layers are more sensitive than fully-connected layers
- Monitor accuracy: Always check the accuracy impact of pruning to ensure an acceptable trade-off.
- Combine with quantization: For maximum compression, combine pruning with quantization techniques (see the sketch after this list).
- Structured vs. unstructured: Unstructured pruning gives better theoretical compression, but structured pruning often gives better practical speedups on common hardware.
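As a rough sketch of that pruning-plus-quantization combination, using post-training dynamic quantization on the linear layers (the 50% ratio is just an example, and model is assumed to contain nn.Linear layers such as fc1):
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Prune, bake the zeros in, then dynamically quantize the linear layers
prune.l1_unstructured(model.fc1, name='weight', amount=0.5)
prune.remove(model.fc1, 'weight')

quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)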
Pruning and Model Export
After pruning, you'll likely want to export your model for deployment. Here's how to export a pruned model to a production-ready format:
# Make pruning permanent before export
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        try:
            prune.remove(module, 'weight')
        except ValueError:
            pass  # Raised when the module was never pruned; safe to skip
# Export to TorchScript
scripted_model = torch.jit.script(model)
scripted_model.save("pruned_model.pt")
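Note that TorchScript still serializes the zeroed weights densely, so the .pt file won't shrink just because the model is sparse. If checkpoint size matters, one option (a sketch; the tensors must be densified again before loading into standard layers) is to store the sparsest tensors in COO format:
sparse_state = {}
for name, tensor in model.state_dict().items():
    # Convert majority-zero float tensors to sparse COO; keep the rest dense
    if tensor.is_floating_point() and (tensor == 0).float().mean() > 0.5:
        sparse_state[name] = tensor.to_sparse()
    else:
        sparse_state[name] = tensor
torch.save(sparse_state, "pruned_model_sparse.pt")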
Summary
Pruning is a powerful technique for making PyTorch models smaller and faster without significantly impacting accuracy. We've covered:
- Basic concepts of pruning and its benefits
- Different pruning techniques in PyTorch (unstructured, structured, global)
- Implementing iterative pruning workflows
- Making pruning permanent for deployment
- Best practices for effective pruning
By applying these techniques, you can optimize your deep learning models to run efficiently on resource-constrained platforms like mobile devices, edge hardware, or in production systems where latency matters.
Additional Resources and Exercises
Resources
- PyTorch Pruning Documentation
- Research paper: "To prune, or not to prune: exploring the efficacy of pruning for model compression"
- Research paper: "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks"
Exercises
- Baseline Comparison: Train a simple CNN on MNIST, then create pruned versions with 50%, 70%, and 90% sparsity. Compare their size, inference speed, and accuracy.
- Sensitivity Analysis: Apply different pruning ratios to different layers of a network and analyze which layers are most sensitive to pruning.
- Pruning Schedule Implementation: Implement a pruning schedule that gradually increases sparsity during training, rather than pruning all at once.
- Combined Compression: Combine pruning with quantization to create an ultra-efficient model that maintains at least 95% of the original accuracy.
- Structured vs. Unstructured: Compare the actual speed improvements between a model with 80% unstructured pruning versus one with 50% structured pruning.
By mastering pruning techniques, you'll be able to deploy more efficient models without sacrificing the performance that makes deep learning so powerful.