PyTorch Model Parallelism
When working with deep learning models, you might encounter situations where your model is too large to fit into the memory of a single GPU. This is where model parallelism comes into play. Unlike data parallelism, which distributes batches of data across multiple devices, model parallelism splits the model itself across multiple devices.
What is Model Parallelism?
Model parallelism is a distributed training technique that divides a neural network model across multiple computing devices (usually GPUs). This approach is useful when:
- Your model is too large to fit into a single GPU's memory
- The model consists of components that can be naturally separated
- You want to optimize the use of specialized hardware for different parts of your model
Why Use Model Parallelism?
As deep learning models grow in size and complexity (GPT-style transformers, for example), their weights, gradients, and optimizer states may no longer fit into the memory of a single GPU. Models with billions of parameters typically cannot be trained on one device at all, so the model itself has to be partitioned.
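A quick back-of-envelope calculation makes this concrete. The figures below are illustrative assumptions (a 7-billion-parameter model, fp32 weights, Adam optimizer state), not measurements from any specific model:
# Illustrative arithmetic: memory to hold a 7B-parameter model in fp32,
# before any activations are counted (all figures are assumptions)
params = 7e9
bytes_per_param = 4                  # fp32
weights_gb = params * bytes_per_param / 1024**3
grads_gb = weights_gb                # gradients match the weights in size
adam_gb = 2 * weights_gb             # Adam keeps two extra fp32 buffers per parameter

print(f"weights alone:          {weights_gb:.0f} GiB")                             # ~26 GiB
print(f"weights + grads + Adam: {weights_gb + grads_gb + adam_gb:.0f} GiB")        # ~104 GiB
Even the weights alone exceed many single-GPU memory budgets, and the full training state is several times larger, which is exactly the situation model parallelism addresses.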
Basic Model Parallelism in PyTorch
Let's start with a simple example where we manually split a model across two GPUs:
import torch
import torch.nn as nn

class ModelParallelExample(nn.Module):
    def __init__(self):
        super(ModelParallelExample, self).__init__()
        # First part of the model on GPU 0
        self.part1 = nn.Sequential(
            nn.Linear(20, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU()
        ).to('cuda:0')
        # Second part of the model on GPU 1
        self.part2 = nn.Sequential(
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 10)
        ).to('cuda:1')

    def forward(self, x):
        # Input on CPU or GPU 0
        x = x.to('cuda:0')
        # Forward through part1
        x = self.part1(x)
        # Transfer to GPU 1
        x = x.to('cuda:1')
        # Forward through part2
        return self.part2(x)

# Create model instance
model = ModelParallelExample()

# Sample input
input_tensor = torch.randn(32, 20)  # Batch size 32, input size 20
output = model(input_tensor)

print(f"Output shape: {output.shape}")
print(f"Output device: {output.device}")
Output:
Output shape: torch.Size([32, 10])
Output device: cuda:1
In this example, we:
- Split our model into two parts
- Placed each part on a different GPU
- Handled the transfer of data between GPUs in the forward pass
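Training this model works just like training on a single GPU, with one caveat: the labels must be moved to the device that holds the final output before computing the loss, and the optimizer simply sees parameters living on both GPUs. Here is a minimal training-step sketch; the loss function and optimizer are illustrative choices, and model refers to the ModelParallelExample instance created above:
import torch
import torch.optim as optim

loss_fn = torch.nn.MSELoss()                          # illustrative loss for the random data below
optimizer = optim.SGD(model.parameters(), lr=0.001)   # covers parameters on both GPUs

optimizer.zero_grad()
outputs = model(torch.randn(32, 20))                  # output lands on cuda:1
labels = torch.randn(32, 10).to('cuda:1')             # labels must match the output device
loss = loss_fn(outputs, labels)
loss.backward()                                       # autograd moves gradients back across devices
optimizer.step()
print(f"loss: {loss.item():.4f}")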
Note on Memory Transfer Overhead
While this approach works, it introduces overhead from transferring intermediate outputs between GPUs. This can sometimes outweigh the benefits of using multiple GPUs, especially for small models or fast operations.
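If you want to see what that transfer actually costs on your hardware, a rough micro-benchmark like the following can help; the tensor shape matches the 500-feature intermediate activation from the example above, and the measurement approach is just one reasonable option:
import time
import torch

x = torch.randn(32, 500, device='cuda:0')    # same shape as the intermediate activation above

# Warm-up so one-time CUDA initialization does not skew the numbers
for _ in range(10):
    _ = x.to('cuda:1')

torch.cuda.synchronize('cuda:0')
torch.cuda.synchronize('cuda:1')
start = time.perf_counter()
for _ in range(100):
    _ = x.to('cuda:1')
torch.cuda.synchronize('cuda:1')              # wait for all copies to finish
elapsed_ms = (time.perf_counter() - start) * 1000 / 100
print(f"Average cuda:0 -> cuda:1 copy: {elapsed_ms:.3f} ms")
Comparing this number against the time each sub-network takes to run tells you whether the split is worthwhile.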
Pipeline Parallelism
With a naive split like the one above, only one GPU works at any given moment while the others sit idle. Pipeline parallelism reduces this idle time by splitting each batch into micro-batches and streaming them through the stages, so different GPUs can process different micro-batches concurrently; recent PyTorch releases also ship a dedicated pipeline API (torch.distributed.pipelining).
Here's a simplified example using nn.Sequential stages and manually splitting the batch into chunks:
import torch
import torch.nn as nn

class PipelineParallelModule(nn.Module):
    def __init__(self, split_size=32):
        super(PipelineParallelModule, self).__init__()
        # Networks on different devices
        self.stage1 = nn.Sequential(
            nn.Linear(20, 500),
            nn.ReLU()
        ).to('cuda:0')
        self.stage2 = nn.Sequential(
            nn.Linear(500, 500),
            nn.ReLU()
        ).to('cuda:1')
        self.stage3 = nn.Sequential(
            nn.Linear(500, 10)
        ).to('cuda:0')
        self.split_size = split_size

    def forward(self, x):
        # Split input into chunks for pipeline parallelism
        chunks = x.split(self.split_size)
        output_chunks = []
        for chunk in chunks:
            # Process through stage 1 (GPU 0)
            out = self.stage1(chunk.to('cuda:0'))
            # Process through stage 2 (GPU 1)
            out = self.stage2(out.to('cuda:1'))
            # Process through stage 3 (GPU 0)
            out = self.stage3(out.to('cuda:0'))
            output_chunks.append(out)
        # Concatenate results
        return torch.cat(output_chunks, dim=0)

# Create model
model = PipelineParallelModule()

# Sample input (batch size 128)
input_tensor = torch.randn(128, 20)
output = model(input_tensor)

print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output.shape}")
print(f"Output device: {output.device}")
Output:
Input shape: torch.Size([128, 20])
Output shape: torch.Size([128, 10])
Output device: cuda:0
Because the loop above handles one chunk at a time, the stages do not truly overlap here; the chunking mainly keeps each intermediate activation, and therefore each cross-device transfer and the per-stage peak memory, small. A full pipeline implementation schedules micro-batches so that while GPU 1 works on one chunk, GPU 0 is already processing the next.
Using PyTorch's Built-in Tools for Model Parallelism
For production use cases, PyTorch's distributed package provides more sophisticated building blocks:
1. Multi-process setup with torch.nn.parallel.DistributedDataParallel
Strictly speaking, DistributedDataParallel (DDP) implements data parallelism, not model parallelism: each process keeps a full replica of the model and gradients are averaged across processes. It is shown here because its process-group setup is the foundation that tensor-parallel and hybrid schemes build on:
import os
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist

class LargeModel(nn.Module):
    def __init__(self):
        super(LargeModel, self).__init__()
        # A single large layer; DDP replicates it on every rank
        self.large_layer = nn.Linear(1000, 1000)

    def forward(self, x):
        return self.large_layer(x)

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Initialize process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def run_model(rank, world_size):
    print(f"Running on rank {rank}")
    setup(rank, world_size)
    # Create model and move to this rank's GPU
    model = LargeModel().to(rank)
    # Wrap model with DistributedDataParallel
    ddp_model = DistributedDataParallel(model, device_ids=[rank])
    # Create input tensor
    input_tensor = torch.randn(20, 1000).to(rank)
    # Forward pass
    output = ddp_model(input_tensor)
    print(f"Rank: {rank}, Output shape: {output.shape}")
    cleanup()

def main():
    world_size = torch.cuda.device_count()
    mp.spawn(run_model,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    main()
2. Tensor parallelism with torch.distributed.tensor.parallel (newer API)
In newer PyTorch versions (roughly 2.2+), tensor parallelism (TP) shards individual layers across GPUs by splitting their weight matrices. The module-level API lives in torch.distributed.tensor.parallel. The sketch below assumes a distributed process group has already been initialized (for example via torchrun):
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel

# Requires PyTorch 2.x with a distributed process group already set up
def tensor_parallel_example():
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Describe the GPUs that will share each layer as a 1-D device mesh
    mesh = init_device_mesh("cuda", (world_size,))

    # A small two-layer block whose weight matrices will be sharded
    model = nn.Sequential(
        nn.Linear(2048, 2048),
        nn.ReLU(),
        nn.Linear(2048, 2048)
    ).cuda()

    # Shard the first Linear column-wise and the second row-wise across the mesh
    tp_model = parallelize_module(
        model, mesh,
        {"0": ColwiseParallel(), "2": RowwiseParallel()}
    )

    # Every rank runs the same forward pass; TP inserts the collectives that combine the shards
    inp = torch.randn(32, 2048, device="cuda")
    output = tp_model(inp)
    print(f"Rank {rank}: output shape {output.shape}")
    return output
Practical Use Case: Large Language Model Training
Let's see how to apply model parallelism to train a simple but large language model:
import torch
import torch.nn as nn

class LargeLanguageModel(nn.Module):
    def __init__(self, vocab_size=50000, embed_dim=2048, num_heads=16, num_layers=24):
        super().__init__()
        # Token embeddings on GPU 0
        self.token_embedding = nn.Embedding(vocab_size, embed_dim).to('cuda:0')
        # First half of transformer layers on GPU 0
        # (batch_first=True because inputs are shaped [batch, seq, embed])
        self.layers_first_half = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers // 2)
        ]).to('cuda:0')
        # Second half of transformer layers on GPU 1
        self.layers_second_half = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers // 2)
        ]).to('cuda:1')
        # Output layer on GPU 1
        self.output_layer = nn.Linear(embed_dim, vocab_size).to('cuda:1')

    def forward(self, x):
        # Embedding on GPU 0
        x = x.to('cuda:0')
        x = self.token_embedding(x)
        # First half processing on GPU 0
        for layer in self.layers_first_half:
            x = layer(x)
        # Transfer to GPU 1
        x = x.to('cuda:1')
        # Second half processing on GPU 1
        for layer in self.layers_second_half:
            x = layer(x)
        # Output projection on GPU 1
        return self.output_layer(x)

# Example usage:
# model = LargeLanguageModel()
# input_ids = torch.randint(0, 50000, (4, 512))  # Batch size 4, sequence length 512
# outputs = model(input_ids)
# print(outputs.shape)  # Expected: [4, 512, 50000]
Best Practices for Model Parallelism
- Analyze the model before splitting: Identify computational and memory bottlenecks to determine where to split the model.
- Balance the workload: Try to distribute computation evenly across GPUs (see the timing sketch after this list).
- Minimize cross-device communication: Place layers that need to communicate frequently on the same device.
- Consider using specialized libraries: For large models, consider libraries like DeepSpeed, Megatron-LM, or PyTorch's own parallel computing tools.
- Combine with other techniques: Model parallelism often works best when combined with data parallelism and other optimization techniques.
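As a rough way to check the workload balance mentioned above, you can time each stage of the earlier PipelineParallelModule in isolation. The helper below is a sketch, and the sample shapes mirror that example:
import time
import torch

def time_stage(stage, sample, device, iters=50):
    sample = sample.to(device)
    for _ in range(5):                 # warm-up
        _ = stage(sample)
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(iters):
        _ = stage(sample)
    torch.cuda.synchronize(device)
    return (time.perf_counter() - start) * 1000 / iters

model = PipelineParallelModule()
x = torch.randn(32, 20)                # a chunk entering stage 1
h = torch.randn(32, 500)               # the activation shape entering stages 2 and 3
print(f"stage1 (cuda:0): {time_stage(model.stage1, x, 'cuda:0'):.3f} ms")
print(f"stage2 (cuda:1): {time_stage(model.stage2, h, 'cuda:1'):.3f} ms")
print(f"stage3 (cuda:0): {time_stage(model.stage3, h, 'cuda:0'):.3f} ms")
If one stage dominates, move layers between stages until the times are roughly even.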
Common Challenges and Solutions
| Challenge | Solution |
| --- | --- |
| Cross-device transfer overhead | Use pipeline parallelism or optimize buffer sizes |
| Unbalanced workload | Profile your model and adjust splits for balanced computation |
| Synchronization issues | Use proper barriers and synchronization primitives |
| Memory leaks | Monitor GPU memory usage and clear caches when necessary |
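For the memory-related rows in the table, a small helper that reports per-GPU usage between training steps makes leaks and imbalances easy to spot; this is a sketch, and the function name is my own:
import torch

def report_gpu_memory(tag=""):
    # Print allocated and reserved memory for every visible GPU
    for i in range(torch.cuda.device_count()):
        alloc = torch.cuda.memory_allocated(i) / 1024**2
        reserved = torch.cuda.memory_reserved(i) / 1024**2
        print(f"{tag} cuda:{i}: allocated {alloc:.1f} MiB, reserved {reserved:.1f} MiB")

# Typical usage around a training step:
# report_gpu_memory("before step")
# ...forward / backward / optimizer.step()...
# report_gpu_memory("after step")
# torch.cuda.empty_cache()   # returns cached blocks to the driver if fragmentation is an issue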
Combining Model and Data Parallelism
For the most efficient distributed training, you can combine model parallelism with data parallelism:
import torch
import torch.nn as nn
import torch.distributed as dist
import os

# Simple hybrid parallelism example (pseudocode)
def setup_hybrid_parallelism():
    # Setup process groups
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # Initialize main process group
    dist.init_process_group("nccl")

    # Get global rank and world size
    global_rank = dist.get_rank()
    global_world_size = dist.get_world_size()

    # Configure model parallel group (e.g., 2 GPUs per model)
    model_parallel_size = 2
    data_parallel_size = global_world_size // model_parallel_size

    # Create model parallel groups and remember the one this rank belongs to
    model_parallel_group = None
    for i in range(data_parallel_size):
        ranks = list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
        group = dist.new_group(ranks=ranks, backend="nccl")
        if global_rank in ranks:
            model_parallel_group = group

    # Create data parallel groups and remember the one this rank belongs to
    data_parallel_group = None
    for i in range(model_parallel_size):
        ranks = list(range(i, global_world_size, model_parallel_size))
        group = dist.new_group(ranks=ranks, backend="nccl")
        if global_rank in ranks:
            data_parallel_group = group

    return global_rank, global_world_size, model_parallel_group, data_parallel_group
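In practice, a script built around this setup is started with PyTorch's launcher, for example torchrun --nproc_per_node=4 train.py (the script name here is just a placeholder); torchrun sets the RANK and WORLD_SIZE environment variables that dist.init_process_group("nccl") reads when no explicit rank is passed.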
Summary
Model parallelism is a powerful technique for training large neural networks that don't fit on a single GPU. In this tutorial, we've covered:
- Basic concepts of model parallelism
- Manual implementation of model splitting across GPUs
- Pipeline parallelism for more efficient processing
- Using PyTorch's built-in distributed tools
- A practical example with a large language model
- Best practices and common challenges
By understanding model parallelism, you can train significantly larger and more complex models than would be possible on a single GPU.
Additional Resources
- PyTorch Distributed Overview
- FSDP (Fully Sharded Data Parallel) Documentation
- DeepSpeed Library for more advanced model parallelism
- Megatron-LM - NVIDIA's implementation for training large transformer models
Exercises
- Basic Exercise: Modify the basic model parallelism example to use 3 GPUs instead of 2.
- Intermediate Exercise: Implement a simple image classification model using pipeline parallelism and measure the performance difference compared to a single-GPU implementation.
- Advanced Exercise: Combine model parallelism with data parallelism to train a large model on a multi-GPU system, managing both model sharding and data distribution.
- Research Exercise: Experiment with different model splitting strategies and analyze how they affect training time and memory usage.