PyTorch Custom Modules
Introduction
PyTorch's modular design is one of its greatest strengths, allowing you to build complex neural networks by combining simpler building blocks. While PyTorch provides many pre-built layers and functions through its torch.nn package, you'll often need to create custom components for your specific requirements.
In this tutorial, we'll learn how to create custom modules by extending PyTorch's nn.Module class. Custom modules allow you to:
- Encapsulate complex logic in reusable components
- Create novel network architectures not available in standard libraries
- Implement specialized layers with custom forward and backward passes
- Organize your code more effectively by grouping related operations
Whether you're implementing a paper from scratch or designing your own neural network architecture, understanding how to create custom modules is an essential PyTorch skill.
Understanding nn.Module
The nn.Module class is the foundation of all neural network modules in PyTorch. It provides:
- Parameter management: Automatically tracks and registers parameters
- Device management: Easily moves your model between CPU and GPU
- Serialization: Supports saving and loading model states
- Module composition: Allows building hierarchical structures
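As a quick, illustrative demo of these built-in capabilities (a minimal sketch using the built-in nn.Linear layer rather than a custom module):

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

# Parameter management: weight and bias are registered automatically
for name, param in layer.named_parameters():
    print(name, tuple(param.shape))

# Device management: move the whole module in one call
if torch.cuda.is_available():
    layer = layer.to("cuda")

# Serialization: capture and restore the module's state
state = layer.state_dict()
layer.load_state_dict(state)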
Every custom module in PyTorch should inherit from nn.Module and implement at least two methods:
- __init__(): Initialize the module and define its components
- forward(): Define how inputs are processed to produce outputs
Let's start with a simple example:
import torch
import torch.nn as nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(MyLinear, self).__init__()
        # Create and register trainable parameters
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x):
        # Define the computation
        return torch.matmul(x, self.weight.t()) + self.bias

# Create an instance of our custom module
layer = MyLinear(10, 5)
input_tensor = torch.randn(3, 10)  # Batch of 3 samples, 10 features each
output = layer(input_tensor)
print(output.shape)  # Output: torch.Size([3, 5])
In this example, we created a simple linear layer that performs the operation y = xW^T + b. Because the weight and bias are wrapped in nn.Parameter, PyTorch automatically tracks them as trainable parameters.
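We can confirm that the weight and bias were registered by iterating over layer.named_parameters(); anything registered this way can be handed straight to an optimizer. A quick check, reusing the layer created above:

for name, param in layer.named_parameters():
    print(name, tuple(param.shape), param.requires_grad)
# weight (5, 10) True
# bias (5,) True

# Registered parameters are exactly what an optimizer needs
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)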
Creating a Custom Activation Function
Let's implement a custom activation function called "Swish", defined as f(x) = x * sigmoid(x). We'll include a scale factor beta, so the module computes x * sigmoid(beta * x); with beta = 1 this is exactly the standard Swish:
class Swish(nn.Module):
    def __init__(self, beta=1.0):
        super(Swish, self).__init__()
        # Beta can be a fixed value or a trainable parameter
        self.beta = nn.Parameter(torch.tensor([beta]))

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

# Test our custom activation
swish = Swish()
x = torch.linspace(-5, 5, 10)
y = swish(x)
print("Input:", x)
print("Output:", y)
Output:
Input: tensor([-5.0000, -3.8889, -2.7778, -1.6667, -0.5556, 0.5556, 1.6667, 2.7778, 3.8889, 5.0000])
Output: tensor([-0.0335, -0.0780, -0.1626, -0.2648, -0.2025, 0.3530, 1.4019, 2.6152, 3.8109, 4.9665], grad_fn=<MulBackward0>)
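Because beta is an nn.Parameter, it receives gradients just like a weight and can be learned during training. A minimal sketch of that, reusing the swish instance and the output y from above:

y.sum().backward()      # Backpropagate a dummy scalar through the activation
print(swish.beta.grad)  # beta now has a gradient, so an optimizer could update it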
Building a Complex Custom Module
Now let's create a more complex module that combines multiple operations. We'll implement a custom residual block, a key component in ResNet architectures:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        # Main path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut connection (skip connection)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        # Store input for skip connection
        identity = x
        # Main path
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Add skip connection
        out += self.shortcut(identity)
        out = self.relu(out)
        return out

# Test our residual block
block = ResidualBlock(in_channels=64, out_channels=128, stride=2)
x = torch.randn(1, 64, 32, 32)  # [batch_size, channels, height, width]
y = block(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
Output:
Input shape: torch.Size([1, 64, 32, 32])
Output shape: torch.Size([1, 128, 16, 16])
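Because ResidualBlock is itself an nn.Module, it composes like any built-in layer. Here is a small, illustrative stage built with nn.Sequential (the channel counts and depth are arbitrary choices for the example):

stage = nn.Sequential(
    ResidualBlock(64, 128, stride=2),   # Downsamples 32x32 -> 16x16
    ResidualBlock(128, 128, stride=1),  # Keeps the resolution
    ResidualBlock(128, 128, stride=1),
)
x = torch.randn(1, 64, 32, 32)
print(stage(x).shape)  # torch.Size([1, 128, 16, 16])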
Nested Modules
One of the strengths of PyTorch's module system is the ability to nest modules within each other. Let's create a custom network using our previously defined components:
class CustomNetwork(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=128, num_classes=10):
        super(CustomNetwork, self).__init__()
        # Feature extraction
        self.features = nn.Sequential(
            nn.Linear(input_dim, hidden_dim * 2),
            Swish(),
            nn.Linear(hidden_dim * 2, hidden_dim),
            Swish()
        )
        # Classifier
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # Flatten the input if it's not already flat
        if len(x.shape) > 2:
            batch_size = x.size(0)
            x = x.view(batch_size, -1)
        # Pass through feature extractor
        features = self.features(x)
        # Classify
        logits = self.classifier(features)
        return logits

# Create and test the network
model = CustomNetwork()
sample_input = torch.randn(32, 784)  # 32 samples of flattened MNIST images
output = model(sample_input)
print(f"Network output shape: {output.shape}")  # Should be [32, 10]
Output:
Network output shape: torch.Size([32, 10])
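A custom module drops into the standard training loop unchanged. The following is a minimal, illustrative single training step for CustomNetwork using dummy data (the optimizer and learning rate are arbitrary choices):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

inputs = torch.randn(32, 784)           # Dummy batch of flattened images
labels = torch.randint(0, 10, (32,))    # Dummy class labels

optimizer.zero_grad()
logits = model(inputs)
loss = loss_fn(logits, labels)
loss.backward()
optimizer.step()
print(f"Training loss: {loss.item():.4f}")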
Implementing Custom Loss Functions
You can also create custom loss functions as modules:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        ce_loss = self.ce_loss(inputs, targets)
        pt = torch.exp(-ce_loss)  # Model's probability for the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:  # 'none'
            return focal_loss

# Example usage
criterion = FocalLoss(gamma=2)
logits = torch.randn(5, 3)  # 5 samples, 3 classes
targets = torch.tensor([0, 1, 2, 1, 0])
loss = criterion(logits, targets)
print(f"Focal loss: {loss.item()}")
Modules with Learnable Parameters vs. Functional Operations
Sometimes you might wonder whether to implement an operation as a module or just use a function. Here's a general guideline:
- Use a module if:
  - The operation has learnable parameters
  - You want to toggle training modes (like dropout or batch normalization)
  - You need to maintain state across calls
- Use a function if:
  - The operation is stateless (no parameters, no modes)
  - It's a simple mathematical operation
For example, ReLU can be implemented either way:
# As a module
class ReLUModule(nn.Module):
    def forward(self, x):
        return torch.relu(x)

# Or as a functional call
import torch.nn.functional as F

def relu_function(x):
    return F.relu(x)
For simple operations like ReLU, the functional approach is often more concise, but using modules provides consistency with the rest of your network structure.
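Dropout is a good illustration of why the module form matters: nn.Dropout automatically respects model.train() and model.eval(), while the functional F.dropout needs the training flag passed explicitly. A minimal sketch (the WithDropout class is just for illustration):

import torch.nn.functional as F

class WithDropout(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.dropout = nn.Dropout(p)  # Tracks self.training for you

    def forward(self, x):
        a = self.dropout(x)                              # Module form
        b = F.dropout(x, p=0.5, training=self.training)  # Functional form
        return a, b

m = WithDropout()
m.eval()  # In eval mode both paths become the identity
x = torch.randn(2, 4)
a, b = m(x)
print(torch.equal(a, x), torch.equal(b, x))  # True True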
Accessing Module Parameters
PyTorch makes it easy to access and manipulate the parameters of your custom modules:
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleModel()

# Iterate through named parameters
print("Named parameters:")
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

# Get all parameters as a list
params_list = list(model.parameters())
print(f"\nNumber of parameter tensors: {len(params_list)}")
Output:
Named parameters:
fc1.weight: torch.Size([20, 10])
fc1.bias: torch.Size([20])
fc2.weight: torch.Size([5, 20])
fc2.bias: torch.Size([5])
Number of parameter tensors: 4
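Parameter access is also how you freeze part of a model or decide what to optimize. A short, illustrative example that freezes fc1 and trains only fc2:

# Freeze the first layer
for param in model.fc1.parameters():
    param.requires_grad = False

# Hand only the still-trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)
print(f"Trainable tensors: {len(trainable)}")  # 2 (fc2.weight and fc2.bias)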
Real-World Example: Attention Mechanism
Let's implement a self-attention module, a key component in many modern neural networks including transformers:
class SelfAttention(nn.Module):
    def __init__(self, embedding_dim, heads=8):
        super(SelfAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.heads = heads
        self.head_dim = embedding_dim // heads
        assert (self.head_dim * heads == embedding_dim), "Embedding dimension must be divisible by number of heads"
        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(embedding_dim, embedding_dim)
        self.k_proj = nn.Linear(embedding_dim, embedding_dim)
        self.v_proj = nn.Linear(embedding_dim, embedding_dim)
        # Output projection
        self.out_proj = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        # x shape: [batch_size, seq_length, embedding_dim]
        batch_size, seq_length, _ = x.size()
        # Project inputs to queries, keys and values
        q = self.q_proj(x).reshape(batch_size, seq_length, self.heads, self.head_dim)
        k = self.k_proj(x).reshape(batch_size, seq_length, self.heads, self.head_dim)
        v = self.v_proj(x).reshape(batch_size, seq_length, self.heads, self.head_dim)
        # Transpose for attention calculation
        q = q.transpose(1, 2)  # [batch_size, heads, seq_length, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Calculate attention scores
        attention = torch.matmul(q, k.transpose(2, 3))  # [batch, heads, seq_len, seq_len]
        attention = attention / (self.head_dim ** 0.5)  # Scale
        attention = torch.softmax(attention, dim=-1)
        # Apply attention to values
        out = torch.matmul(attention, v)  # [batch, heads, seq_len, head_dim]
        out = out.transpose(1, 2)  # [batch, seq_len, heads, head_dim]
        out = out.reshape(batch_size, seq_length, self.embedding_dim)
        # Final projection
        out = self.out_proj(out)
        return out

# Test the attention module
batch_size = 2
seq_length = 10
embedding_dim = 256

attention = SelfAttention(embedding_dim=embedding_dim, heads=8)
x = torch.randn(batch_size, seq_length, embedding_dim)
output = attention(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
Output:
Input shape: torch.Size([2, 10, 256])
Output shape: torch.Size([2, 10, 256])
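As a quick size check, we can count the module's parameters. With embedding_dim=256, the four projection layers contribute 4 x (256*256 + 256) = 263,168 parameters:

num_params = sum(p.numel() for p in attention.parameters())
print(f"SelfAttention parameters: {num_params}")  # 263168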
Best Practices for Custom Modules
- Consistent initialization: Always call the parent class's __init__ method with super().__init__()
- Register parameters correctly: Use nn.Parameter for trainable parameters
- Module naming: Use descriptive names for submodules to help with debugging
- Forward method signature: Keep your forward() method signature consistent and well-documented
- Input validation: Include input checks in your forward method, especially during development
- Device compatibility: Make your modules device-agnostic by using the input tensor's device
class BestPracticeModule(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()  # Use the shorter super() syntax in Python 3
        # Validate inputs
        assert input_dim > 0, "Input dimension must be positive"
        assert hidden_dim > 0, "Hidden dimension must be positive"
        # Store configuration for reference
        self.config = {
            'input_dim': input_dim,
            'hidden_dim': hidden_dim,
        }
        # Create submodules with descriptive names
        self.feature_extractor = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.output_projector = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Process input tensor through the module.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, input_dim]

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, 1]
        """
        # Validate input
        if x.dim() != 2 or x.size(1) != self.config['input_dim']:
            raise ValueError(f"Expected input shape [batch_size, {self.config['input_dim']}], "
                             f"got {list(x.shape)}")
        # Forward pass
        h = self.feature_extractor(x)
        h = self.activation(h)
        output = self.output_projector(h)
        return output
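For the device-compatibility point above: when a forward pass has to create new tensors, derive the device and dtype from the input instead of hard-coding them. A minimal sketch of the pattern (the DeviceAgnosticNoise module is purely illustrative):

class DeviceAgnosticNoise(nn.Module):
    """Adds Gaussian noise created on the same device/dtype as the input."""
    def __init__(self, std=0.1):
        super().__init__()
        self.std = std

    def forward(self, x):
        # torch.randn_like inherits x's device and dtype automatically
        noise = torch.randn_like(x) * self.std
        return x + noise

noisy = DeviceAgnosticNoise()
print(noisy(torch.randn(2, 3)).shape)  # Works the same on CPU or GPU tensors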
Debugging Custom Modules
When your custom module doesn't work as expected, try these debugging strategies:
- Print intermediate shapes: Add print statements in the forward method
- Use hooks: Register forward/backward hooks to inspect tensors
- Check parameter initialization: Verify parameters are created properly
- Validate gradients: Check that gradients are flowing correctly
Here's an example with hooks:
def hook_fn(module, input, output):
    print(f"Module: {module.__class__.__name__}")
    print(f"Input shapes: {[x.shape if isinstance(x, torch.Tensor) else x for x in input]}")
    print(f"Output shape: {output.shape}")
    print("---")

model = CustomNetwork()

# Register a forward hook on the first layer of the feature extractor
model.features[0].register_forward_hook(hook_fn)

# Run a forward pass
input_tensor = torch.randn(2, 784)
output = model(input_tensor)
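For the gradient-validation point, a rough but effective diagnostic is to run one backward pass and check that every parameter received a gradient:

model = CustomNetwork()  # Fresh instance without the hook attached
model(torch.randn(4, 784)).sum().backward()  # Dummy scalar loss to populate gradients

for name, param in model.named_parameters():
    status = "OK" if param.grad is not None else "NO GRADIENT"
    print(f"{name}: {status}")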
Summary
In this tutorial, you've learned:
- How to create custom modules by extending nn.Module
- The importance of the __init__() and forward() methods
- How to create and register parameters and submodules
- Building complex modules by nesting simpler ones
- Implementing custom activation functions and loss functions
- Best practices for designing and debugging custom modules
Custom modules are a powerful tool in your PyTorch toolkit. They allow you to encapsulate complex logic, create reusable components, and build novel architectures not available in standard libraries. As you develop more sophisticated deep learning models, the ability to create custom modules will become increasingly valuable.
Exercises
- Create a custom module that implements the LeakyReLU activation function with a learnable negative slope parameter.
- Implement a custom residual block that uses depthwise separable convolutions instead of standard convolutions.
- Create a custom module that combines batch normalization and dropout in a single component.
- Build a custom attention module that implements multi-head cross-attention between two sequences.
- Design a custom loss function module that combines L1 and L2 losses with learnable weights.