
PyTorch nn Module

Introduction

The torch.nn module is the cornerstone of neural network development in PyTorch. It provides all the building blocks necessary to design complex neural network architectures, from simple linear layers to convolutional networks and recurrent models. This module abstracts away many implementation details, allowing you to focus on the architecture design rather than low-level operations.

In this tutorial, you'll learn:

  • What the PyTorch nn module is and why it's important
  • The key components of nn.Module (the base class)
  • How to create custom neural network layers
  • How to build complete neural networks
  • Best practices for using nn modules in your deep learning projects

Understanding nn.Module - The Foundation

What is nn.Module?

nn.Module is the base class for all neural network modules in PyTorch. It provides a structured way to represent a layer or a complete neural network that performs computations on tensors.

python
import torch
import torch.nn as nn

# Creating a simple module
class MySimpleModule(nn.Module):
    def __init__(self):
        super().__init__()
        # This is where you define your layers

    def forward(self, x):
        # This is where you define the computation flow
        return x

Key Features of nn.Module

  1. Parameter Management: Automatically tracks and manages model parameters
  2. Module Composition: Allows nesting modules within modules
  3. GPU/CPU Transfer: Easily move models between devices
  4. Training/Evaluation Modes: Switch between training and evaluation with model.train() and model.eval()
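
The short sketch below touches all four of these features with a single linear layer; the layer type and sizes are arbitrary placeholders.

python
import torch
import torch.nn as nn

model = nn.Linear(10, 5)

# 1. Parameter management: weight and bias are registered automatically
for name, param in model.named_parameters():
    print(name, param.shape)

# 2. Module composition: modules can be nested inside other modules
wrapped = nn.Sequential(model, nn.ReLU())

# 3. GPU/CPU transfer: move every parameter with a single call
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wrapped.to(device)

# 4. Training/evaluation modes (affects layers such as Dropout and BatchNorm)
wrapped.train()
wrapped.eval()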

Building Blocks: Common nn Components

Linear Layers

The most basic component is the fully connected layer, implemented as nn.Linear:

python
import torch
import torch.nn as nn

# Define a linear layer (y = Wx + b)
linear = nn.Linear(in_features=10, out_features=5)

# Input tensor
x = torch.randn(32, 10) # 32 samples, 10 features each

# Pass through the layer
y = linear(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
# Output:
# Input shape: torch.Size([32, 10])
# Output shape: torch.Size([32, 5])
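
Since nn.Linear is itself an nn.Module, its weight and bias are registered as parameters automatically. A quick check (the shapes below follow from in_features=10 and out_features=5):

python
import torch.nn as nn

linear = nn.Linear(in_features=10, out_features=5)

# The weight matrix is stored as (out_features, in_features)
print(linear.weight.shape)  # torch.Size([5, 10])
print(linear.bias.shape)    # torch.Size([5])

# Total parameter count: 10 * 5 weights + 5 biases = 55
print(sum(p.numel() for p in linear.parameters()))  # 55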

Activation Functions

Activation functions introduce non-linearity into networks:

python
import torch
import torch.nn as nn

# Create some common activation functions
relu = nn.ReLU()
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()

# Input tensor
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])

# Apply activations
print(f"Original: {x}")
print(f"ReLU: {relu(x)}")
print(f"Sigmoid: {sigmoid(x)}")
print(f"Tanh: {tanh(x)}")

# Output:
# Original: tensor([-2., -1., 0., 1., 2.])
# ReLU: tensor([0., 0., 0., 1., 2.])
# Sigmoid: tensor([0.1192, 0.2689, 0.5000, 0.7311, 0.8808])
# Tanh: tensor([-0.9640, -0.7616, 0.0000, 0.7616, 0.9640])
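
The same activations also exist as stateless functions in torch.nn.functional (the F.relu form used in the forward methods later in this tutorial); a minimal equivalent of the module versions above:

python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])

# Functional counterparts of nn.ReLU(), nn.Sigmoid(), nn.Tanh()
print(F.relu(x))
print(torch.sigmoid(x))
print(torch.tanh(x))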

Convolutional Layers

For processing grid-like data such as images:

python
import torch
import torch.nn as nn

# Define a 2D convolutional layer
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)

# Input tensor (batch_size, channels, height, width)
x = torch.randn(1, 3, 28, 28) # Single RGB image of size 28x28

# Apply convolution
y = conv(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
# Output:
# Input shape: torch.Size([1, 3, 28, 28])
# Output shape: torch.Size([1, 16, 28, 28])
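
The spatial size follows the usual convolution formula out = (in + 2 * padding - kernel_size) // stride + 1, so with kernel_size=3, stride=1, padding=1 a 28x28 input stays 28x28. A tiny check (conv_out_size is just an illustrative helper, not part of PyTorch):

python
# Output size along one spatial dimension (dilation ignored)
def conv_out_size(in_size, kernel_size, stride, padding):
    return (in_size + 2 * padding - kernel_size) // stride + 1

print(conv_out_size(28, kernel_size=3, stride=1, padding=1))  # 28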

Pooling Layers

Reduce spatial dimensions while retaining important features:

python
import torch
import torch.nn as nn

# Max pooling layer
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

# Average pooling layer
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

# Input tensor
x = torch.randn(1, 1, 4, 4)
print(f"Original:\n{x.squeeze()}")

# Apply pooling
max_result = max_pool(x)
avg_result = avg_pool(x)

print(f"After MaxPool2d:\n{max_result.squeeze()}")
print(f"After AvgPool2d:\n{avg_result.squeeze()}")
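
With kernel_size=2 and stride=2, both pooling layers halve the height and width, so the 1x1x4x4 input above comes out as 1x1x2x2; a quick shape check:

python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 4, 4)
pooled = nn.MaxPool2d(kernel_size=2, stride=2)(x)
print(pooled.shape)  # torch.Size([1, 1, 2, 2])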

Normalization Layers

Stabilize and accelerate training:

python
import torch
import torch.nn as nn

# Batch normalization layer
batch_norm = nn.BatchNorm1d(num_features=5)

# Input tensor (batch_size, features)
x = torch.randn(32, 5)

# Apply batch normalization
y = batch_norm(x)

print(f"Input mean: {x.mean(dim=0)}")
print(f"Input std: {x.std(dim=0)}")
print(f"Output mean: {y.mean(dim=0)}") # Close to 0
print(f"Output std: {y.std(dim=0)}") # Close to 1
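
Note that BatchNorm behaves differently in training and evaluation mode: in training it normalizes with the current batch statistics and updates its running estimates, while in eval mode it uses the stored running_mean and running_var buffers. A minimal sketch:

python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=5)

bn.train()
_ = bn(torch.randn(32, 5))    # normalizes with batch statistics, updates running estimates

bn.eval()
_ = bn(torch.randn(32, 5))    # normalizes with the stored running statistics
print(bn.running_mean.shape)  # torch.Size([5])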

Creating Your First Neural Network

Let's build a simple convolutional classifier sized for MNIST-style images (1×28×28 grayscale digits):

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # First convolutional layer
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        # Second convolutional layer
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Max pooling layer
        self.pool = nn.MaxPool2d(kernel_size=2)
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        # Dropout for regularization
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        # First conv block
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)

        # Second conv block
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool(x)

        # Flatten the output for the fully connected layers
        x = torch.flatten(x, 1)

        # Fully connected layers with dropout
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        # Output layer with log-softmax activation
        output = F.log_softmax(x, dim=1)
        return output

# Create the model
model = SimpleConvNet()
print(model)

# Sample input
sample = torch.randn(1, 1, 28, 28)
output = model(sample)
print(f"Output shape: {output.shape}") # Should be [1, 10] (one log-probability per digit)
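
Because the model returns log-probabilities, it pairs naturally with nn.NLLLoss. The following is a minimal training-loop sketch; it assumes you have already built an MNIST DataLoader named train_loader (not shown here), and the optimizer and learning rate are arbitrary choices:

python
import torch.nn as nn
import torch.optim as optim

model = SimpleConvNet()
criterion = nn.NLLLoss()                          # pairs with log_softmax outputs
optimizer = optim.SGD(model.parameters(), lr=0.01)

model.train()
for images, labels in train_loader:               # train_loader is assumed to exist
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()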

Container Modules

PyTorch provides container modules to organize layers:

Sequential

The simplest way to stack layers:

python
import torch
import torch.nn as nn

# Creating a network using Sequential
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1)
)

# Forward pass
input_data = torch.randn(32, 784) # 32 samples, 784 features (28x28 flattened)
output = model(input_data)
print(f"Output shape: {output.shape}") # Should be [32, 10]

ModuleList

Use nn.ModuleList when you need more control over the forward pass, for example when the number of layers is decided at runtime:

python
import torch
import torch.nn as nn

class DynamicNetwork(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        # Create a ModuleList of linear layers
        self.layers = nn.ModuleList([
            nn.Linear(layer_sizes[i], layer_sizes[i+1])
            for i in range(len(layer_sizes) - 1)
        ])

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Apply ReLU to all but the last layer
            if i < len(self.layers) - 1:
                x = torch.relu(x)
        return x

# Create a network with custom architecture
model = DynamicNetwork([784, 512, 256, 128, 10])

# Forward pass
input_data = torch.randn(32, 784)
output = model(input_data)
print(f"Output shape: {output.shape}")

ModuleDict

Organize modules by name:

python
import torch
import torch.nn as nn

class ModelWithBranches(nn.Module):
    def __init__(self):
        super().__init__()
        self.branches = nn.ModuleDict({
            'branch1': nn.Linear(100, 10),
            'branch2': nn.Linear(100, 20),
            'branch3': nn.Linear(100, 30)
        })

    def forward(self, x, branch_name):
        return self.branches[branch_name](x)

# Create the model
model = ModelWithBranches()

# Sample input
x = torch.randn(5, 100)

# Get output from different branches
out1 = model(x, 'branch1')
out2 = model(x, 'branch2')
out3 = model(x, 'branch3')

print(f"Branch 1 output shape: {out1.shape}") # [5, 10]
print(f"Branch 2 output shape: {out2.shape}") # [5, 20]
print(f"Branch 3 output shape: {out3.shape}") # [5, 30]

Practical Example: Transfer Learning

Let's see how to use pre-trained models from torchvision with the nn module:

python
import torch
import torch.nn as nn
import torchvision.models as models

# Load a pre-trained ResNet model
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 'pretrained=True' in older torchvision versions

# Freeze all the parameters to prevent training
for param in resnet.parameters():
    param.requires_grad = False

# Replace the final fully connected layer
# ResNet18's fc layer has input features of 512
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 10) # 10 classes for our custom task

print(f"Modified ResNet model:")
print(f"Final layer: {resnet.fc}")

# Now only the final layer will be trained
trainable_params = sum(p.numel() for p in resnet.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {trainable_params}")
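
When you train this model, pass only the unfrozen parameters to the optimizer. A short sketch continuing from the resnet example above (the optimizer choice and learning rate are arbitrary):

python
import torch.optim as optim

# Only the new fc layer has requires_grad=True, so only it will be updated
optimizer = optim.Adam(
    (p for p in resnet.parameters() if p.requires_grad),
    lr=1e-3,
)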

Common Operations with nn.Module

Saving and Loading Models

python
import torch
import torch.nn as nn

# Define a simple model
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)

# Save the model
torch.save(model.state_dict(), 'simple_model.pth')

# Load the model
new_model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)
new_model.load_state_dict(torch.load('simple_model.pth'))
new_model.eval() # Set to evaluation mode

print("Model successfully saved and loaded!")
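
For resuming training, it is common to save a checkpoint dictionary that also carries the optimizer state and the current epoch; a minimal sketch (the filename and dictionary keys are just a convention, not required by PyTorch):

python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 5)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Save model and optimizer state together
torch.save({
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')

# Restore both later
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])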

Moving Models Between Devices

python
import torch
import torch.nn as nn

# Create a model
model = nn.Linear(10, 5)

# Check if CUDA (GPU) is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Move model to the appropriate device
model.to(device)

# Create input tensor on the same device
input_data = torch.randn(3, 10, device=device)

# Forward pass (computation happens on the device)
output = model(input_data)
print(f"Output shape: {output.shape}")

Getting Parameters and Buffers

python
import torch
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)
        self.fc = nn.Linear(16 * 26 * 26, 10)

    def forward(self, x):
        return self.fc(self.bn(self.conv(x)).flatten(1))

# Create the model
model = CustomModel()

# Print model parameters
print("Model Parameters:")
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

# Print model buffers (like running_mean in BatchNorm)
print("\nModel Buffers:")
for name, buf in model.named_buffers():
    print(f"{name}: {buf.shape}")
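
Because named_parameters() yields tensors, counting them with len() gives the number of parameter tensors, not the number of individual weights; use numel() to count scalars. Continuing with the CustomModel above:

python
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total}, trainable: {trainable}")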

Best Practices

  1. Organize Complex Models:

    • Break down complex models into submodules for better organization
    • Use nn.Sequential for linear sequences of layers
  2. Parameter Initialization:

    • Use provided initialization methods like nn.init.xavier_uniform_
python
def weight_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = SimpleConvNet()
model.apply(weight_init)

  3. Use Built-in Layers:

    • Prefer built-in nn modules over manual implementations for better performance
  4. Debugging:

    • Print intermediate shapes to debug network flow
    • Use torchsummary for a layer-by-layer summary of output shapes and parameter counts
python
# pip install torchsummary
from torchsummary import summary

model = SimpleConvNet()
summary(model, (1, 28, 28))

Summary

The torch.nn module is a powerful toolset for building neural networks in PyTorch. We've covered:

  • The foundation: nn.Module
  • Basic building blocks (Linear, Conv2d, activation functions)
  • Container modules (Sequential, ModuleList, ModuleDict)
  • Creating custom neural networks
  • Common operations (saving/loading, device management)
  • Best practices for organizing and debugging your models

With these tools, you can create anything from simple classifiers to complex, state-of-the-art deep learning architectures.

Additional Resources

  1. PyTorch Official nn Documentation
  2. PyTorch Tutorials
  3. Deep Learning with PyTorch: A 60 Minute Blitz

Exercises

  1. Create a simple feedforward network for the MNIST dataset using only Linear layers
  2. Implement a custom layer that applies a specific mathematical operation
  3. Build a CNN for image classification on the CIFAR-10 dataset
  4. Implement transfer learning using a pre-trained model from torchvision for a custom dataset
  5. Create a model with multiple input branches that combines features at a later stage

By mastering the nn module, you've taken a significant step toward building sophisticated deep learning models in PyTorch!


