PyTorch nn Module
Introduction
The torch.nn module is the cornerstone of neural network development in PyTorch. It provides all the building blocks necessary to design complex neural network architectures, from simple linear layers to convolutional networks and recurrent models. This module abstracts away many implementation details, allowing you to focus on the architecture design rather than low-level operations.
In this tutorial, you'll learn:
- What the PyTorch nn module is and why it's important
- The key components of nn.Module (the base class)
- How to create custom neural network layers
- How to build complete neural networks
- Best practices for using nn modules in your deep learning projects
Understanding nn.Module - The Foundation
What is nn.Module?
nn.Module is the base class for all neural network modules in PyTorch. It provides a structured way to represent a layer or a complete neural network that performs computations on tensors.
import torch
import torch.nn as nn
# Creating a simple module
class MySimpleModule(nn.Module):
    def __init__(self):
        super().__init__()
        # This is where you define your layers

    def forward(self, x):
        # This is where you define the computation flow
        return x
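When a layer needs its own learnable weights, you register them as nn.Parameter so that nn.Module can track them automatically. Here is a minimal sketch of such a custom layer (ScaleShift is just an illustrative name, not a built-in module):

import torch
import torch.nn as nn

# A minimal custom layer: y = x * scale + shift, with learnable scale and shift
class ScaleShift(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        # nn.Parameter registers these tensors so nn.Module tracks them
        self.scale = nn.Parameter(torch.ones(num_features))
        self.shift = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        return x * self.scale + self.shift

layer = ScaleShift(4)
print(list(layer.named_parameters()))  # shows the 'scale' and 'shift' parameters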
Key Features of nn.Module
- Parameter Management: Automatically tracks and manages model parameters
- Module Composition: Allows nesting modules within modules
- GPU/CPU Transfer: Easily move models between devices
- Training/Evaluation Modes: Switch between training and evaluation with model.train() and model.eval()
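As a quick illustration of these features, here is a minimal sketch (the layer sizes are arbitrary and only serve this example):

import torch
import torch.nn as nn

# A tiny model to demonstrate nn.Module features
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

# Parameter Management: every parameter is tracked automatically
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params}")

# GPU/CPU Transfer: move the whole model (and all its parameters) at once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training/Evaluation Modes: affects layers such as Dropout and BatchNorm
model.train()  # training mode
model.eval()   # evaluation mode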
Building Blocks: Common nn Components
Linear Layers
The most basic component is the fully connected layer, implemented as nn.Linear:
import torch
import torch.nn as nn
# Define a linear layer (y = Wx + b)
linear = nn.Linear(in_features=10, out_features=5)
# Input tensor
x = torch.randn(32, 10) # 32 samples, 10 features each
# Pass through the layer
y = linear(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
# Output:
# Input shape: torch.Size([32, 10])
# Output shape: torch.Size([32, 5])
Activation Functions
Activation functions introduce non-linearity into networks:
import torch
import torch.nn as nn
# Create some common activation functions
relu = nn.ReLU()
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()
# Input tensor
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
# Apply activations
print(f"Original: {x}")
print(f"ReLU: {relu(x)}")
print(f"Sigmoid: {sigmoid(x)}")
print(f"Tanh: {tanh(x)}")
# Output:
# Original: tensor([-2., -1., 0., 1., 2.])
# ReLU: tensor([0., 0., 0., 1., 2.])
# Sigmoid: tensor([0.1192, 0.2689, 0.5000, 0.7311, 0.8808])
# Tanh: tensor([-0.9640, -0.7616, 0.0000, 0.7616, 0.9640])
Convolutional Layers
For processing grid-like data such as images:
import torch
import torch.nn as nn
# Define a 2D convolutional layer
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
# Input tensor (batch_size, channels, height, width)
x = torch.randn(1, 3, 28, 28) # Single RGB image of size 28x28
# Apply convolution
y = conv(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
# Output:
# Input shape: torch.Size([1, 3, 28, 28])
# Output shape: torch.Size([1, 16, 28, 28])
Pooling Layers
Reduce spatial dimensions while retaining important features:
import torch
import torch.nn as nn
# Max pooling layer
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Average pooling layer
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
# Input tensor
x = torch.randn(1, 1, 4, 4)
print(f"Original:\n{x.squeeze()}")
# Apply pooling
max_result = max_pool(x)
avg_result = avg_pool(x)
print(f"After MaxPool2d:\n{max_result.squeeze()}")
print(f"After AvgPool2d:\n{avg_result.squeeze()}")
Normalization Layers
Stabilize and accelerate training:
import torch
import torch.nn as nn
# Batch normalization layer
batch_norm = nn.BatchNorm1d(num_features=5)
# Input tensor (batch_size, features)
x = torch.randn(32, 5)
# Apply batch normalization
y = batch_norm(x)
print(f"Input mean: {x.mean(dim=0)}")
print(f"Input std: {x.std(dim=0)}")
print(f"Output mean: {y.mean(dim=0)}") # Close to 0
print(f"Output std: {y.std(dim=0)}") # Close to 1
Creating Your First Neural Network
Let's build a simple image classifier using the MNIST dataset:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # First convolutional layer
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        # Second convolutional layer
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Max pooling layer
        self.pool = nn.MaxPool2d(kernel_size=2)
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        # Dropout for regularization
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        # First conv block
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        # Second conv block
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool(x)
        # Flatten the output for the fully connected layers
        x = torch.flatten(x, 1)
        # Fully connected layers with dropout
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        # Output layer with log-softmax activation
        output = F.log_softmax(x, dim=1)
        return output

# Create the model
model = SimpleConvNet()
print(model)

# Sample input
sample = torch.randn(1, 1, 28, 28)
output = model(sample)
print(f"Output shape: {output.shape}")  # Should be [1, 10] (one log-probability per digit class)
Container Modules
PyTorch provides container modules to organize layers:
Sequential
The simplest way to stack layers:
import torch
import torch.nn as nn
# Creating a network using Sequential
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1)
)
# Forward pass
input_data = torch.randn(32, 784) # 32 samples, 784 features (28x28 flattened)
output = model(input_data)
print(f"Output shape: {output.shape}") # Should be [32, 10]
ModuleList
For when you need more control over the flow:
import torch
import torch.nn as nn
class DynamicNetwork(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        # Create a ModuleList of linear layers
        self.layers = nn.ModuleList([
            nn.Linear(layer_sizes[i], layer_sizes[i + 1])
            for i in range(len(layer_sizes) - 1)
        ])

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Apply ReLU to all but the last layer
            if i < len(self.layers) - 1:
                x = torch.relu(x)
        return x
# Create a network with custom architecture
model = DynamicNetwork([784, 512, 256, 128, 10])
# Forward pass
input_data = torch.randn(32, 784)
output = model(input_data)
print(f"Output shape: {output.shape}")
ModuleDict
Organize modules by name:
import torch
import torch.nn as nn
class ModelWithBranches(nn.Module):
    def __init__(self):
        super().__init__()
        self.branches = nn.ModuleDict({
            'branch1': nn.Linear(100, 10),
            'branch2': nn.Linear(100, 20),
            'branch3': nn.Linear(100, 30)
        })

    def forward(self, x, branch_name):
        return self.branches[branch_name](x)
# Create the model
model = ModelWithBranches()
# Sample input
x = torch.randn(5, 100)
# Get output from different branches
out1 = model(x, 'branch1')
out2 = model(x, 'branch2')
out3 = model(x, 'branch3')
print(f"Branch 1 output shape: {out1.shape}") # [5, 10]
print(f"Branch 2 output shape: {out2.shape}") # [5, 20]
print(f"Branch 3 output shape: {out3.shape}") # [5, 30]
Practical Example: Transfer Learning
Let's see how to use pre-trained models from torchvision with the nn module:
import torch
import torch.nn as nn
import torchvision.models as models
# Load a pre-trained ResNet model
# (on torchvision >= 0.13, prefer weights=models.ResNet18_Weights.DEFAULT over the deprecated pretrained=True)
resnet = models.resnet18(pretrained=True)

# Freeze all the parameters to prevent training
for param in resnet.parameters():
    param.requires_grad = False
# Replace the final fully connected layer
# ResNet18's fc layer has input features of 512
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 10) # 10 classes for our custom task
print(f"Modified ResNet model:")
print(f"Final layer: {resnet.fc}")
# Now only the final layer will be trained
trainable_params = sum(p.numel() for p in resnet.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {trainable_params}")
Common Operations with nn.Module
Saving and Loading Models
import torch
import torch.nn as nn
# Define a simple model
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)
# Save the model
torch.save(model.state_dict(), 'simple_model.pth')
# Load the model
new_model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)
new_model.load_state_dict(torch.load('simple_model.pth'))
new_model.eval() # Set to evaluation mode
print("Model successfully saved and loaded!")
Moving Models Between Devices
import torch
import torch.nn as nn
# Create a model
model = nn.Linear(10, 5)
# Check if CUDA (GPU) is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Move model to the appropriate device
model.to(device)
# Create input tensor on the same device
input_data = torch.randn(3, 10, device=device)
# Forward pass (computation happens on the device)
output = model(input_data)
print(f"Output shape: {output.shape}")
Getting Parameters and Buffers
import torch
import torch.nn as nn
class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)
        self.fc = nn.Linear(16 * 26 * 26, 10)

    def forward(self, x):
        return self.fc(self.bn(self.conv(x)).flatten(1))
# Create the model
model = CustomModel()
# Print model parameters
print("Model Parameters:")
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

# Print model buffers (like running_mean in BatchNorm)
print("\nModel Buffers:")
for name, buf in model.named_buffers():
    print(f"{name}: {buf.shape}")
Best Practices
- Organize Complex Models: Break down complex models into submodules for better organization, and use nn.Sequential for simple linear sequences of layers.
- Parameter Initialization: Use the provided initialization functions, such as nn.init.xavier_uniform_:

def weight_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = SimpleConvNet()
model.apply(weight_init)

- Use Built-in Layers: Prefer built-in nn modules over manual implementations for better performance.
- Debugging: Print intermediate shapes to debug network flow (see the hook sketch after this list), and use torchsummary for a summary of layer output shapes and parameter counts:

# pip install torchsummary
from torchsummary import summary

model = SimpleConvNet()
summary(model, (1, 28, 28))
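For the shape-debugging tip above, one lightweight approach is to register forward hooks that print each submodule's output shape. A rough sketch, reusing SimpleConvNet from earlier (hook removal included so the hooks don't linger):

import torch

model = SimpleConvNet()

# Register a forward hook on every leaf submodule to print its output shape
handles = []
for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # leaf modules only
        def make_hook(name):
            def hook(mod, inputs, output):
                print(f"{name}: {tuple(output.shape)}")
            return hook
        handles.append(module.register_forward_hook(make_hook(name)))

# A single forward pass prints the shape after every layer
model(torch.randn(1, 1, 28, 28))

# Clean up the hooks afterwards
for h in handles:
    h.remove()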
Summary
The torch.nn module is a powerful toolset for building neural networks in PyTorch. We've covered:
- The foundation: nn.Module
- Basic building blocks (Linear, Conv2d, activation functions)
- Container modules (Sequential, ModuleList, ModuleDict)
- Creating custom neural networks
- Common operations (saving/loading, device management)
- Best practices for organizing and debugging your models
With these tools, you can create anything from simple classifiers to complex, state-of-the-art deep learning architectures.
Exercises
- Create a simple feedforward network for the MNIST dataset using only Linear layers
- Implement a custom layer that applies a specific mathematical operation
- Build a CNN for image classification on the CIFAR-10 dataset
- Implement transfer learning using a pre-trained model from torchvision for a custom dataset
- Create a model with multiple input branches that combines features at a later stage
By mastering the nn module, you've taken a significant step toward building sophisticated deep learning models in PyTorch!