PyTorch Unit Testing

Introduction

Unit testing is a critical practice in software development that involves testing individual components or "units" of code in isolation to ensure they work as expected. In the context of PyTorch, unit testing helps verify that your models, data transformations, and training procedures behave correctly across different scenarios and input types.

Effective testing can:

  • Catch bugs early in the development process
  • Ensure model behavior remains consistent after refactoring
  • Provide documentation for how your code should behave
  • Enable more confident collaboration with other developers

In this tutorial, we'll explore how to set up and write effective unit tests for PyTorch code. We'll use pytest, a popular Python testing framework that makes it easy to write small, readable tests.

Setting Up Your Testing Environment

Prerequisites

Before we start, make sure you have the following installed:

bash
pip install pytest torch torchvision

Project Structure

A typical PyTorch project with tests might be structured like this:

my_pytorch_project/
├── models/
│   ├── __init__.py
│   ├── simple_cnn.py
│   └── linear_model.py
├── datasets/
│   ├── __init__.py
│   └── custom_dataset.py
├── train.py
├── utils.py
└── tests/
    ├── __init__.py
    ├── test_models.py
    ├── test_datasets.py
    └── test_utils.py

Basic PyTorch Unit Testing

Let's start with a simple example. Suppose we have a function that performs a basic operation on a tensor:

python
# utils.py
import torch

def normalize_tensor(x):
    """
    Normalize a tensor to have values between 0 and 1.
    """
    min_val = torch.min(x)
    max_val = torch.max(x)
    return (x - min_val) / (max_val - min_val)

Here's how we would test this function:

python
# tests/test_utils.py
import torch
import pytest
from utils import normalize_tensor

def test_normalize_tensor():
    # Create a test tensor
    test_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

    # Apply normalization
    normalized = normalize_tensor(test_tensor)

    # Check that values are between 0 and 1
    assert torch.all(normalized >= 0).item()
    assert torch.all(normalized <= 1).item()

    # Check that the minimum value is 0 and the maximum is 1
    assert torch.isclose(torch.min(normalized), torch.tensor(0.0))
    assert torch.isclose(torch.max(normalized), torch.tensor(1.0))
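
One edge case worth covering: normalize_tensor divides by (max_val - min_val), so a constant input produces NaN. A small additional test (not part of the original example, it simply documents the current behavior) might look like this:

python
# tests/test_utils.py (additional test)
import torch
from utils import normalize_tensor

def test_normalize_constant_tensor():
    # All elements are equal, so max - min == 0 and the division yields NaN
    constant_tensor = torch.ones(5)
    normalized = normalize_tensor(constant_tensor)

    # This documents the current behavior; if you would rather return zeros,
    # change normalize_tensor and flip this assertion accordingly.
    assert torch.isnan(normalized).all()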

To run these tests, navigate to your project root directory and execute:

bash
python -m pytest tests/test_utils.py -v

The -v flag increases verbosity, showing you which tests passed or failed.
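
You can also run the entire suite at once, or a single test selected by its node ID:

bash
# Run every test under tests/
python -m pytest tests/ -v

# Run only one test function from a file
python -m pytest tests/test_utils.py::test_normalize_tensor -v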

Testing PyTorch Models

Testing neural network models requires a bit more thought. Let's test a simple linear model:

python
# models/linear_model.py
import torch
import torch.nn as nn

class SimpleLinearModel(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

Here's how we can test this model:

python
# tests/test_models.py
import torch
import pytest
from models.linear_model import SimpleLinearModel

def test_simple_linear_model():
    # Test parameters
    batch_size = 8
    input_size = 10
    output_size = 5

    # Create an instance of the model
    model = SimpleLinearModel(input_size, output_size)

    # Create a random input tensor
    x = torch.randn(batch_size, input_size)

    # Forward pass
    output = model(x)

    # Check the output shape
    assert output.shape == (batch_size, output_size)

    # Check that the output differs from the input (a transformation happened)
    assert not torch.allclose(x[:, :output_size], output)

    # Test with a batch size of 1
    single_input = torch.randn(1, input_size)
    single_output = model(single_input)
    assert single_output.shape == (1, output_size)

Testing with Different Device Types

One common issue in PyTorch is ensuring that your code works across different device types (CPU and GPU). Here's how to write tests that can handle this:

python
# tests/test_device_compatibility.py
import torch
import pytest
from models.simple_cnn import SimpleCNN

@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_model_on_device(device):
    # Skip if we're testing CUDA and CUDA is not available
    if "cuda" in device and not torch.cuda.is_available():
        pytest.skip("CUDA not available, skipping GPU test")

    # Model parameters
    in_channels = 3
    num_classes = 10

    # Create the model and move it to the device
    model = SimpleCNN(in_channels, num_classes).to(device)

    # Create input data on the same device
    x = torch.randn(2, in_channels, 32, 32, device=device)

    # Forward pass
    output = model(x)

    # Check the output shape and device
    assert output.shape == (2, num_classes)
    assert output.device.type == device.split(':')[0]
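
The test above imports SimpleCNN from models/simple_cnn.py, which isn't defined anywhere in this tutorial. A minimal sketch that matches the shapes used in the test (3-channel 32x32 inputs, num_classes outputs) could look like the following; treat it as an illustrative stand-in rather than the actual model:

python
# models/simple_cnn.py (illustrative sketch)
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # One 2x2 pooling step turns a 32x32 input into 16x16
        self.fc = nn.Linear(16 * 16 * 16, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = torch.flatten(x, 1)
        return self.fc(x)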

Testing Model Training and Loss Calculation

Let's test that training works correctly:

python
# tests/test_training.py
import torch
import torch.nn as nn
import torch.optim as optim
from models.linear_model import SimpleLinearModel

def test_model_training():
    # Create a model
    input_size = 10
    output_size = 2
    model = SimpleLinearModel(input_size, output_size)

    # Create some dummy data
    x = torch.randn(20, input_size)
    y = torch.randint(0, output_size, (20,))

    # Set up the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Get the initial loss
    initial_output = model(x)
    initial_loss = criterion(initial_output, y)

    # Train for a few steps
    for _ in range(5):
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

    # Get the final loss
    final_output = model(x)
    final_loss = criterion(final_output, y)

    # Check that the loss decreased
    assert final_loss < initial_loss

Testing Custom Datasets

If you have custom PyTorch datasets, you should test them too:

python
# tests/test_datasets.py
import torch
import pytest
from datasets.custom_dataset import CustomDataset

def test_custom_dataset():
    # Create a small test dataset
    dataset = CustomDataset(root="test_data", train=True, download=True)

    # Check that the dataset is not empty
    assert len(dataset) > 0

    # Get a sample
    item = dataset[0]

    # Check that the item has the expected format (image and label)
    image, label = item

    # Check the image shape (assuming it should be 3x224x224)
    assert image.shape == (3, 224, 224)

    # Check that the label is a scalar
    assert isinstance(label, int)

    # Test the data loader
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, num_workers=0
    )

    # Check a batch
    batch = next(iter(dataloader))
    images, labels = batch

    assert images.shape == (4, 3, 224, 224)
    assert labels.shape == (4,)
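
As with SimpleCNN above, CustomDataset isn't shown in this tutorial. A minimal synthetic stand-in that would satisfy the assertions in the test (3x224x224 image tensors paired with integer labels) might look like this:

python
# datasets/custom_dataset.py (illustrative sketch)
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, root, train=True, download=False, num_samples=16):
        # A real implementation would load files from `root`; here we
        # generate random data so the test has something to iterate over.
        self.data = torch.randn(num_samples, 3, 224, 224)
        self.labels = torch.randint(0, 10, (num_samples,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Return a plain Python int so `isinstance(label, int)` holds
        return self.data[idx], int(self.labels[idx])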

Testing for Numerical Stability

Neural networks can sometimes suffer from numerical instability. Let's test for this:

python
# tests/test_numerical_stability.py
import torch
import pytest
from models.linear_model import SimpleLinearModel

def test_numerical_stability():
    model = SimpleLinearModel(input_size=10, output_size=5)

    # Test with very large inputs
    large_input = torch.ones(2, 10) * 1e10
    large_output = model(large_input)

    # Check for NaN or infinite values
    assert not torch.isnan(large_output).any()
    assert not torch.isinf(large_output).any()

    # Test with very small inputs
    small_input = torch.ones(2, 10) * 1e-10
    small_output = model(small_input)

    assert not torch.isnan(small_output).any()
    assert not torch.isinf(small_output).any()

Using Fixtures in Pytest

Fixtures are a powerful pytest feature that lets you create reusable test components:

python
# tests/conftest.py
import pytest
import torch
from models.linear_model import SimpleLinearModel

@pytest.fixture
def linear_model():
    """Fixture that returns a simple linear model."""
    return SimpleLinearModel(input_size=10, output_size=5)

@pytest.fixture
def sample_batch():
    """Fixture that returns a sample batch of data."""
    return torch.randn(8, 10)

Now we can use these fixtures in our tests:

python
# tests/test_with_fixtures.py
import torch

def test_model_with_fixtures(linear_model, sample_batch):
    # Use the fixtures directly
    output = linear_model(sample_batch)

    # Check the output shape
    assert output.shape == (8, 5)
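
Fixtures can also be given a scope so that expensive objects are built only once. The variation below (a new fixture name, not part of the files above) creates the model once per test module:

python
# tests/conftest.py (module-scoped variation)
import pytest
from models.linear_model import SimpleLinearModel

@pytest.fixture(scope="module")
def shared_linear_model():
    # Built once per test module and reused by every test in it.
    # Be careful with shared models if any test mutates their weights.
    return SimpleLinearModel(input_size=10, output_size=5)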

Testing for Reproducibility

PyTorch operations should be reproducible when a random seed is set:

python
# tests/test_reproducibility.py
import torch
import torch.nn as nn
import random
import numpy as np

def test_reproducibility():
    # Helper to set all relevant random seeds
    def set_seeds(seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create a simple model
    model = nn.Linear(10, 5)

    # First run
    set_seeds(42)
    input_1 = torch.randn(3, 10)
    output_1 = model(input_1)

    # Second run
    set_seeds(42)
    input_2 = torch.randn(3, 10)
    output_2 = model(input_2)

    # Check that the inputs and outputs are identical
    assert torch.allclose(input_1, input_2)
    assert torch.allclose(output_1, output_2)
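
If you need stricter guarantees, PyTorch can also be asked to use only deterministic algorithm implementations. The snippet below is a small optional addition to the seeding helper above; note that some CUDA operations may then require the CUBLAS_WORKSPACE_CONFIG environment variable or may raise an error:

python
import torch

def enable_full_determinism(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Raise an error (or switch implementations) whenever a
    # nondeterministic operation would otherwise be used
    torch.use_deterministic_algorithms(True)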

Real-World Example: Testing an Image Classifier

Let's put everything together in a comprehensive example testing a more realistic image classifier:

python
# models/image_classifier.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

Now let's test this model comprehensively:

python
# tests/test_image_classifier.py
import torch
import torch.nn as nn
import torch.optim as optim
import pytest
from models.image_classifier import ImageClassifier

@pytest.fixture
def model():
    return ImageClassifier(num_classes=10)

@pytest.fixture
def sample_image_batch():
    return torch.randn(4, 3, 32, 32)

def test_model_output_shape(model, sample_image_batch):
    output = model(sample_image_batch)
    assert output.shape == (4, 10)

def test_model_backward_pass(model, sample_image_batch):
    # Create targets
    target = torch.randint(0, 10, (4,))

    # Set up the optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Forward pass
    output = model(sample_image_batch)
    loss = criterion(output, target)

    # Check that the loss is not NaN
    assert not torch.isnan(loss)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Check gradients
    for name, param in model.named_parameters():
        # Gradients should exist
        assert param.grad is not None
        # Gradients should not be NaN or infinite
        assert not torch.isnan(param.grad).any()
        assert not torch.isinf(param.grad).any()

    # Step the optimizer
    optimizer.step()

def test_model_inference_mode(model, sample_image_batch):
    # Switch to inference mode
    model.eval()
    with torch.no_grad():
        output = model(sample_image_batch)

    # Check the output
    assert output.shape == (4, 10)
    assert not torch.isnan(output).any()

    # Check that dropout is not active in eval mode.
    # We test this indirectly: the output should be deterministic in eval mode.
    with torch.no_grad():
        output1 = model(sample_image_batch)
        output2 = model(sample_image_batch)

    # Outputs should be identical in eval mode
    assert torch.allclose(output1, output2)

Best Practices for PyTorch Unit Testing

  1. Test Different Input Sizes: Your models should handle various batch sizes, image sizes, etc.

  2. Test Edge Cases: Check behavior with extreme inputs like zero, very large numbers, empty inputs, etc.

  3. Test on Both CPU and GPU: If your model will run on different devices, test it on each of them.

  4. Mock Heavy Operations: Use unittest.mock or pytest's monkeypatch for operations that are too heavy to run in tests (see the sketch after this list).

  5. Separate Unit Tests from Integration Tests: Unit tests should be quick and test isolated components. Save full model training tests for integration testing.

  6. Set Random Seeds: When testing components with randomness, set seeds to ensure reproducibility.

  7. Use CI/CD: Set up continuous integration to run your tests automatically.
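
To illustrate point 4, here is a minimal sketch of mocking with pytest's monkeypatch fixture. It swaps torch.hub.load for a cheap fake so the test never downloads pretrained weights; the repository and model names are placeholders, since the call never reaches the real API:

python
# tests/test_mocking.py (illustrative sketch)
import torch
import torch.hub
import torch.nn as nn

def test_without_downloading_weights(monkeypatch):
    # Stand-in that returns a tiny local model instead of downloading one
    def fake_hub_load(repo, model_name, **kwargs):
        return nn.Linear(10, 5)

    monkeypatch.setattr(torch.hub, "load", fake_hub_load)

    # Code under test that would normally trigger a download now gets the fake
    model = torch.hub.load("pytorch/vision", "resnet18", pretrained=True)
    output = model(torch.randn(2, 10))
    assert output.shape == (2, 5)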

Summary

In this tutorial, we've covered the fundamentals of unit testing PyTorch code:

  • Setting up a testing environment with pytest
  • Writing basic tests for tensor operations
  • Testing PyTorch models, datasets, and training loops
  • Ensuring compatibility across different devices
  • Testing for numerical stability and reproducibility
  • A comprehensive real-world example

Effective unit testing is an investment that pays off by catching bugs early, ensuring your models behave as expected, and making your deep learning code more maintainable and reliable.

Exercises

  1. Write a test for a custom loss function that checks if it correctly calculates the expected values for known inputs.

  2. Create a test suite for a custom dataset class that verifies images are loaded correctly and transformations are applied properly.

  3. Write a test that verifies model weights are properly saved and loaded.

  4. Create a test that checks your model produces the same output when run with the same input twice (with proper seeding).

  5. Write a test to verify that your model's output probability distribution sums to 1 after softmax is applied.



If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)