PyTorch Fusion Optimization
Introduction
When training and deploying deep learning models, performance optimization becomes crucial, especially when working with large-scale models or limited computational resources. One of the most powerful optimization techniques in PyTorch is fusion optimization. Fusion refers to combining multiple operations into a single optimized kernel, reducing memory overhead and computation time.
In this tutorial, we'll explore various fusion optimization techniques in PyTorch, understand how they work under the hood, and learn how to apply them to speed up your models significantly.
What is Fusion Optimization?
Fusion optimization is the process of combining multiple operations that would normally execute separately into a single, optimized computation. This delivers performance benefits through:
- Reduced memory transfers - Fewer read/write operations to GPU memory
- Decreased kernel launch overhead - Fewer individual CUDA kernel calls
- Better cache utilization - Data remains in cache across operations
- Enhanced compiler optimizations - More opportunities for optimization at the hardware level
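To put the first two points in rough numbers, here is a back-of-the-envelope sketch (assuming float32 tensors, scalar a and b, and ignoring caches) for the chain y = relu(x * a + b):

# Rough memory-traffic estimate for y = relu(x * a + b) on N float32 elements.
# Unfused: three kernels, each reading one tensor and writing one intermediate.
# Fused:   one kernel that reads x once and writes y once.
N = 32 * 1024 * 1024            # 32M elements
bytes_per_elem = 4              # float32
unfused_gb = 3 * 2 * N * bytes_per_elem / 1e9   # read + write per kernel
fused_gb = 2 * N * bytes_per_elem / 1e9          # single read + single write
print(f"unfused ~ {unfused_gb:.1f} GB moved, fused ~ {fused_gb:.1f} GB moved")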
Let's dive into the different types of fusion optimizations available in PyTorch.
Vertical Fusion
Vertical fusion combines operations that are executed sequentially in the computational graph. A common example is fusing a convolution operation with a subsequent activation function.
Example: Fusing Convolution and ReLU
# Without fusion: the convolution result is materialized, then ReLU reads it back
def without_fusion(x, conv_layer):
    x = conv_layer(x)
    x = torch.relu(x)
    return x

# With fusion: a compiler (e.g. TorchScript or torch.compile) can fuse these
# operations into a single kernel when it traces this function
def with_fusion(x, conv_layer):
    return torch.nn.functional.relu(conv_layer(x))
While this looks like a cosmetic change, the difference shows up once a compiler is involved. In plain eager mode both versions launch the same kernels; but when the function is compiled (for example with TorchScript), PyTorch can generate a single optimized kernel that performs both operations without writing the intermediate result back to memory.
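As a minimal sketch of what a compiler sees (the exact fusion behavior depends on the PyTorch version and backend), scripting a small chain of pointwise operations exposes the graph the fuser works on:

import torch

@torch.jit.script
def pointwise_chain(x, bias):
    # add -> relu -> scale: three elementwise ops a fuser can combine into one kernel
    return torch.relu(x + bias) * 0.5

# The scripted IR that fusion passes operate on; when this runs on a GPU,
# the JIT fuser can group the pointwise ops into a single fusion group at
# execution time (behavior varies across PyTorch versions).
print(pointwise_chain.graph)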
JIT Fusion with TorchScript
PyTorch's JIT (Just-In-Time) compiler can automatically detect and apply fusion optimizations when you use TorchScript. Let's see how to enable this:
import torch

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(100, 100)
        self.linear2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return torch.log_softmax(x, dim=1)

# Create model and input
model = SimpleModel()
example_input = torch.rand(32, 100)

# Create a scripted version with fusion optimization
scripted_model = torch.jit.script(model)
# Compare execution time (uses CUDA events, so it requires a GPU)
def benchmark(model, input_data, iterations=1000):
    start_time = torch.cuda.Event(enable_timing=True)
    end_time = torch.cuda.Event(enable_timing=True)

    # Warmup
    for _ in range(10):
        _ = model(input_data)

    # Benchmark
    start_time.record()
    for _ in range(iterations):
        _ = model(input_data)
    end_time.record()
    torch.cuda.synchronize()
    return start_time.elapsed_time(end_time) / iterations

# Move to GPU for benchmarking (the CUDA-event benchmark only runs there)
if torch.cuda.is_available():
    model = model.cuda()
    scripted_model = scripted_model.cuda()
    example_input = example_input.cuda()

    print(f"Original model: {benchmark(model, example_input):.3f} ms per iteration")
    print(f"Scripted model: {benchmark(scripted_model, example_input):.3f} ms per iteration")
Output (example):
Original model: 0.142 ms per iteration
Scripted model: 0.087 ms per iteration
The scripted model runs noticeably faster: TorchScript removes per-operation Python overhead and, on GPU, its fuser can combine the elementwise operations it finds between the two matrix multiplications.
Horizontal Fusion
PyTorch compilers can also apply horizontal fusion, which combines similar, independent operations that run on the same input. A common example is several convolutions with identical configurations applied to the same tensor. Note that the built-in helpers in torch.nn.utils.fusion (such as fuse_conv_bn_eval) handle vertical convolution + BatchNorm folding rather than this case; there is no one-call public API for horizontal fusion, so it is usually done by hand or left to a compiler backend.
Consider a model with three parallel convolutions (a manual fusion sketch follows the code):
import torch

class MultiConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv3 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, x):
        # Three separate convolutions on the same input
        y1 = self.conv1(x)
        y2 = self.conv2(x)
        y3 = self.conv3(x)
        return y1 + y2 + y3

model = MultiConvModel()
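Below is a minimal sketch of doing this fusion by hand, under the configuration above: the three weight tensors are stacked into one convolution with 192 output channels, so a single kernel launch replaces three, and the three 64-channel slices of its output are summed to reproduce the original result. The helper name fuse_parallel_convs is made up here for illustration.

import torch

def fuse_parallel_convs(m: MultiConvModel) -> torch.nn.Conv2d:
    # One wide convolution whose output channels are the three originals stacked
    fused = torch.nn.Conv2d(3, 192, kernel_size=3, padding=1)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([m.conv1.weight, m.conv2.weight, m.conv3.weight], dim=0))
        fused.bias.copy_(torch.cat([m.conv1.bias, m.conv2.bias, m.conv3.bias], dim=0))
    return fused

fused_conv = fuse_parallel_convs(model)

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y1, y2, y3 = fused_conv(x).split(64, dim=1)   # recover the three branch outputs
    print(torch.allclose(model(x), y1 + y2 + y3, atol=1e-5))  # expect True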
Custom Fusion with CUDA Graphs
For advanced optimization scenarios, PyTorch supports CUDA Graphs, which allow for capturing and replaying entire sequences of GPU operations, eliminating overhead from kernel launches.
import torch

def run_model_with_cuda_graph(model, input_tensor):
    # Fall back to eager execution when no GPU is present
    if not torch.cuda.is_available():
        return lambda new_input: model(new_input)

    # Move to GPU if needed
    if input_tensor.device.type != 'cuda':
        input_tensor = input_tensor.cuda()
    if next(model.parameters()).device.type != 'cuda':
        model = model.cuda()

    # Static input buffer reused by every replay (inference-only capture)
    static_input = input_tensor.clone()

    # Warmup on a side stream, as recommended before graph capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s), torch.no_grad():
        for _ in range(3):
            static_output = model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture the forward pass into a CUDA graph
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g), torch.no_grad():
        static_output = model(static_input)

    # Function to run the graph on new data of the same shape
    def run_graph(new_input):
        static_input.copy_(new_input)
        g.replay()
        return static_output.clone()

    return run_graph
# Example usage
model = SimpleModel().cuda()
example_input = torch.rand(32, 100, device='cuda')
# Create the optimized function
optimized_model = run_model_with_cuda_graph(model, example_input)
# Run the optimized model
result = optimized_model(example_input)
CUDA Graphs attack a different overhead than kernel fusion: they record an entire sequence of kernel launches once and replay it with a single call, removing most of the per-launch CPU overhead between individual operations. This is particularly useful for models with fixed input shapes that are run repeatedly.
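To see the launch-overhead savings concretely, one hedged comparison is to reuse the CUDA-event benchmark() helper defined in the TorchScript section above (numbers will vary by GPU; for a model this small, the gap mostly reflects CPU launch overhead):

# Compare per-call latency: eager forward vs. replaying the captured graph
if torch.cuda.is_available():
    print(f"Eager model:       {benchmark(model, example_input):.3f} ms per iteration")
    print(f"CUDA graph replay: {benchmark(optimized_model, example_input):.3f} ms per iteration")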
Real-World Example: Optimizing a ResNet Block
Let's apply fusion optimization to a real-world example - a ResNet block:
import torch
import torch.nn as nn
import time

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out
# Create a standard ResNet block (this example assumes a CUDA device is available)
regular_block = ResidualBlock(64, 64).cuda()

# Use TorchScript to create an optimized version
scripted_block = torch.jit.script(regular_block)

# The scripted graph can be optimized further for inference; see the freezing
# sketch after the benchmark output below
# Create test data
batch_size = 64
test_input = torch.randn(batch_size, 64, 56, 56, device='cuda')

# Benchmark function
def benchmark_model(model, input_data, iterations=100):
    # Warmup
    for _ in range(10):
        _ = model(input_data)
    torch.cuda.synchronize()

    start_time = time.time()
    for _ in range(iterations):
        _ = model(input_data)
    torch.cuda.synchronize()
    end_time = time.time()
    return (end_time - start_time) / iterations * 1000  # convert to ms
# Run benchmarks
regular_time = benchmark_model(regular_block, test_input)
scripted_time = benchmark_model(scripted_block, test_input)
print(f"Regular block: {regular_time:.3f} ms per batch")
print(f"Scripted block: {scripted_time:.3f} ms per batch")
print(f"Speedup: {regular_time/scripted_time:.2f}x")
Output (example):
Regular block: 0.875 ms per batch
Scripted block: 0.612 ms per batch
Speedup: 1.43x
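For inference, the scripted block can usually be pushed further by freezing it, which folds the BatchNorm statistics into the preceding convolutions and enables additional fusion passes. A hedged sketch (torch.jit.freeze and torch.jit.optimize_for_inference are available in recent PyTorch releases; the measured speedup will vary):

# Freeze an eval-mode copy of the block so BatchNorm can be folded into the convs
inference_block = torch.jit.optimize_for_inference(
    torch.jit.freeze(torch.jit.script(ResidualBlock(64, 64).eval().cuda()))
)

frozen_time = benchmark_model(inference_block, test_input)
print(f"Frozen block: {frozen_time:.3f} ms per batch")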
FX Graph Mode Fusion
PyTorch's FX Graph Mode provides another powerful way to apply fusion optimizations. It works by symbolically tracing your model to generate a graph representation, which can then be transformed:
import torch
from torch.fx import symbolic_trace

class SimpleModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        x = self.linear(x)
        x = torch.relu(x)
        return x

# Create a module
module = SimpleModule()

# Symbolically trace the module
traced_module = symbolic_trace(module)

# Print the graph - you'll see operations that can be fused
print(traced_module.graph)

# A custom fusion pass (conceptual example)
def fuse_linear_relu(gm):
    for node in gm.graph.nodes:
        if node.op == 'call_module' and node.target == 'linear':
            # Look at the node's consumers; fuse only if relu is the sole user
            users = list(node.users)
            if len(users) == 1 and users[0].op == 'call_function' and users[0].target is torch.relu:
                # Create a fused operation
                # This is simplified; a real pass would rewrite the graph here
                print("Found linear+relu pattern that can be fused!")
    return gm

# Apply the fusion pass
optimized_module = fuse_linear_relu(traced_module)
FX Graph Mode is particularly powerful for creating custom optimization passes tailored to your specific model architecture.
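To make the idea concrete, here is a minimal working pass under the same assumptions: it rewrites each top-level nn.Linear whose only consumer is torch.relu into a single LinearReLU wrapper module and deletes the now-redundant relu node. LinearReLU and fuse_linear_relu_pass are names invented for this sketch; a real backend would lower the wrapper to an actual fused kernel.

import torch
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace

class LinearReLU(nn.Module):
    # Stand-in fused module: runs linear + relu inside one forward call
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return torch.relu(self.linear(x))

def fuse_linear_relu_pass(gm: GraphModule) -> GraphModule:
    modules = dict(gm.named_modules())
    for node in list(gm.graph.nodes):
        if node.op != 'call_module' or not isinstance(modules.get(node.target), nn.Linear):
            continue
        users = list(node.users)
        if len(users) == 1 and users[0].op == 'call_function' and users[0].target is torch.relu:
            relu_node = users[0]
            # Swap the Linear submodule for the fused wrapper (top-level names only)
            setattr(gm, node.target, LinearReLU(modules[node.target]))
            # Route the relu's consumers to the (now fused) linear node and drop the relu
            relu_node.replace_all_uses_with(node)
            gm.graph.erase_node(relu_node)
    gm.recompile()
    return gm

fused_module = fuse_linear_relu_pass(symbolic_trace(module))
x = torch.randn(4, 10)
print(torch.allclose(fused_module(x), module(x)))  # expect True
print(fused_module.graph)                          # the standalone relu node is gone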
Best Practices for Fusion Optimization
To get the most out of fusion optimization in PyTorch, follow these best practices:
- Use TorchScript for automatic fusion optimizations.
- Prefer in-place operations where possible to avoid extra allocations:
  # Less efficient: allocates a new tensor
  x = torch.relu(x)
  # More memory-efficient
  x = torch.relu_(x)  # or use nn.ReLU(inplace=True)
- Chain operations without storing intermediates:
  # Less fusion-friendly
  x = layer1(x)
  x = activation(x)
  x = layer2(x)
  # More fusion-friendly
  x = layer2(activation(layer1(x)))
- Use fusion-aware modules like torch.nn.intrinsic.ConvBnReLU2d, which are specifically designed for fused operations (see the sketch after this list).
- Consider input shapes - fusion works best with consistent tensor dimensions.
- Profile before and after - always benchmark to ensure fusion is actually helping.
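As a concrete illustration of the fusion-aware-modules point, a hedged sketch using torch.ao.quantization.fuse_modules (torch.quantization.fuse_modules in older releases): in eval mode it folds the BatchNorm into the convolution and typically replaces the sequence with an intrinsic fused module such as ConvReLU2d. SmallNet below is defined here just for the example.

import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

net = SmallNet().eval()                      # fusion for inference requires eval mode
fused_net = fuse_modules(net, [["conv", "bn", "relu"]])

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    print(torch.allclose(net(x), fused_net(x), atol=1e-5))  # expect True
print(fused_net)   # conv is typically a fused ConvReLU2d; bn and relu become Identity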
Diagnosing Fusion Issues
Sometimes fusion might not work as expected. Here's how to diagnose common issues:
import torch

# Enable profiling to see whether operations are being fused
# (torch.profiler.profile is the newer API; the legacy autograd profiler shown
# here still works)
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    # Run your model here
    result = model(input_tensor)

# Print the profiler output to see kernel calls
print(prof.key_averages().table(sort_by="cuda_time_total"))

# Look for separate kernel calls for operations that could be fused
# (e.g. standalone elementwise kernels); if you see them, you may need to
# restructure your code or check that the fuser is enabled
Summary
PyTorch fusion optimization is a powerful technique for improving the performance of your deep learning models by combining multiple operations into optimized kernels. We've covered:
- Vertical fusion of sequential operations
- Horizontal fusion of parallel operations
- Using TorchScript for automatic fusion
- CUDA Graphs for capturing entire computation sequences
- FX Graph Mode for custom optimization passes
- Best practices for fusion-friendly code
By applying these techniques, you can significantly speed up both training and inference for your PyTorch models, especially on GPU hardware.
Additional Resources
- PyTorch TorchScript Documentation
- PyTorch FX Documentation
- CUDA Graphs in PyTorch
- PyTorch Performance Tuning Guide
Exercises
- Profile a simple neural network with and without TorchScript to measure the speedup from fusion optimization.
- Implement a custom fusion pass using FX Graph Mode to fuse a specific pattern in your model.
- Compare the performance of manually fused operations versus letting TorchScript handle fusion automatically.
- Apply CUDA Graphs to a model with fixed input dimensions and measure the performance improvement.
- Use the PyTorch profiler to identify operations in your model that could benefit from fusion optimization.