PyTorch Lightning
PyTorch Lightning is a lightweight wrapper for PyTorch that helps you organize your code and reduce boilerplate without sacrificing flexibility. In this tutorial, you'll learn how Lightning can transform your deep learning workflow and make your code more readable, modular, and scalable.
Introduction to PyTorch Lightning
Have you ever written PyTorch code that became messy and hard to maintain as your project grew? Perhaps you found yourself copying and pasting training loops, validation steps, and optimization code across different projects? PyTorch Lightning addresses these challenges by providing a structured framework that separates research code from engineering code.
Lightning was created by William Falcon to solve a simple problem: make deep learning code more organized and reproducible without sacrificing flexibility.
Why Use PyTorch Lightning?
- Cleaner, more organized code: Lightning enforces a structured way to organize your PyTorch code
- Less boilerplate: Standard training loops, validation, and testing are handled for you
- Built-in best practices: Automatic checkpointing, logging, early stopping, etc.
- Hardware agnostic: The same code can run on CPU, GPU, TPU, or multi-GPU setups
- Reproducibility: Lightning provides utilities for seeding and deterministic runs (see the snippet below)
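For example, a fixed global seed and deterministic execution can be requested in a couple of lines. This is a minimal sketch; which operations can actually run deterministically depends on your hardware and PyTorch version:

import pytorch_lightning as pl

# Seed Python, NumPy, and PyTorch RNGs (workers=True also seeds DataLoader workers)
pl.seed_everything(42, workers=True)

# Ask the Trainer to prefer deterministic algorithms where available
trainer = pl.Trainer(deterministic=True)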
Getting Started
First, let's install PyTorch Lightning:
pip install pytorch-lightning
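You can verify the installation and check which version you have; the examples below use the classic pytorch_lightning import, which works across recent releases:

import pytorch_lightning as pl
print(pl.__version__)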
Basic Structure: The LightningModule
The core of Lightning is the LightningModule, which organizes your code into specific methods:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
class LitMNISTModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Define model architecture
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 64)
        self.layer_3 = nn.Linear(64, 10)

    def forward(self, x):
        # Forward pass (used for inference)
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)  # Flatten the input
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        # Training loop step
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # Validation loop step
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        # Test loop step
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc)

    def configure_optimizers(self):
        # Define optimizers and LR schedulers
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer
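A LightningModule is still a regular torch.nn.Module, so you can instantiate it and call it on a tensor directly. A quick sanity check of the architecture above, with a fake batch of MNIST-sized inputs, might look like this:

model = LitMNISTModel()
dummy_batch = torch.randn(4, 1, 28, 28)   # four fake grayscale 28x28 images
logits = model(dummy_batch)
print(logits.shape)                        # torch.Size([4, 10])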
Data Module
Lightning also provides an optional LightningDataModule to standardize data handling:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # Download data (runs once, on a single process)
        MNIST(root='./data', train=True, download=True)
        MNIST(root='./data', train=False, download=True)

    def setup(self, stage=None):
        # Prepare data for each stage (fit, test, etc.)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        if stage == 'fit' or stage is None:
            mnist_train = MNIST(root='./data', train=True, transform=transform)
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(root='./data', train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
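You can also drive the DataModule by hand, which is handy for inspecting a batch outside of training. Note that prepare_data and setup are normally called by the Trainer for you; calling them manually here is just for illustration:

dm = MNISTDataModule(batch_size=32)
dm.prepare_data()          # downloads MNIST if it isn't on disk yet
dm.setup(stage='fit')      # builds the train/val splits
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)    # torch.Size([32, 1, 28, 28]) torch.Size([32])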
Training the Model
With Lightning, training is clean and straightforward:
# Initialize the model and data module
model = LitMNISTModel()
data_module = MNISTDataModule()
# Create a trainer
trainer = pl.Trainer(max_epochs=5, accelerator='auto')
# Train the model
trainer.fit(model, data_module)
# Test the model
trainer.test(model, data_module)
Output:
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
---------------------------
0 | layer_1 | Linear | 100 K
1 | layer_2 | Linear | 8.3 K
2 | layer_3 | Linear | 650
---------------------------
109 K Trainable params
0 Non-trainable params
109 K Total params
0.438 Total estimated model params size (MB)
Epoch 0: 100%|██████████| 1719/1719 [00:03<00:00, 452.80it/s, loss=0.139, v_num=1]
Epoch 1: 100%|██████████| 1719/1719 [00:03<00:00, 459.18it/s, loss=0.0486, v_num=1]
Epoch 2: 100%|██████████| 1719/1719 [00:03<00:00, 467.50it/s, loss=0.0398, v_num=1]
Epoch 3: 100%|██████████| 1719/1719 [00:03<00:00, 456.12it/s, loss=0.0294, v_num=1]
Epoch 4: 100%|██████████| 1719/1719 [00:03<00:00, 463.57it/s, loss=0.0228, v_num=1]
Testing: 100%|██████████| 313/313 [00:00<00:00, 484.56it/s]
────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────
test_acc 0.9789
test_loss 0.0659
────────────────────────────────────────────────────────────
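By default the Trainer writes checkpoints under lightning_logs/, and you can reload the trained weights for inference with load_from_checkpoint. The path below is only a placeholder; substitute the checkpoint file your run actually produced:

# Hypothetical checkpoint path -- use the one from your own run
ckpt_path = "lightning_logs/version_1/checkpoints/last.ckpt"

model = LitMNISTModel.load_from_checkpoint(ckpt_path)
model.eval()

with torch.no_grad():
    sample = torch.randn(1, 1, 28, 28)   # stand-in for a real MNIST image
    print(torch.argmax(model(sample), dim=1))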
Advanced Features
Callbacks
Lightning provides callbacks to customize the training process:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
# Keep the 3 best checkpoints based on validation accuracy
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    dirpath='./checkpoints',
    filename='mnist-{epoch:02d}-{val_acc:.2f}',
    save_top_k=3,
    mode='max',
)

# Stop training when validation loss stops improving
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=True,
    mode='min'
)

# Use callbacks in the trainer
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, early_stop_callback],
    accelerator='auto'
)
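After trainer.fit(...) finishes, the checkpoint callback records where the best model was written, so you can reload it later. A short sketch building on the callback defined above:

trainer.fit(model, data_module)

# Path of the best checkpoint according to the monitored metric (val_acc)
print(checkpoint_callback.best_model_path)

# Reload the best weights for evaluation or inference
best_model = LitMNISTModel.load_from_checkpoint(checkpoint_callback.best_model_path)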
Logging
Lightning integrates with popular logging frameworks:
from pytorch_lightning.loggers import TensorBoardLogger
logger = TensorBoardLogger("tb_logs", name="mnist_model")
trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    accelerator='auto'
)
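You can also attach several loggers at once, since the Trainer accepts a list; for example, pairing TensorBoard with a plain CSV log (CSVLogger also lives in pytorch_lightning.loggers):

from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

trainer = pl.Trainer(
    max_epochs=5,
    logger=[
        TensorBoardLogger("tb_logs", name="mnist_model"),
        CSVLogger("csv_logs", name="mnist_model")
    ],
    accelerator='auto'
)

Everything recorded with self.log lands in both backends; run tensorboard --logdir tb_logs to browse the curves.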
Multi-GPU Training
Running on multiple GPUs is as simple as changing a parameter:
trainer = pl.Trainer(
    max_epochs=5,
    devices=2,           # Use 2 GPUs
    accelerator='gpu',
    strategy='ddp'       # Distributed Data Parallel
)
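One practical detail: the 'ddp' strategy launches extra processes that re-import your script, so the training calls should sit under a main guard. A minimal script layout, reusing the MNIST model and data module from earlier:

if __name__ == "__main__":
    model = LitMNISTModel()
    data_module = MNISTDataModule()
    trainer = pl.Trainer(
        max_epochs=5,
        devices=2,
        accelerator='gpu',
        strategy='ddp'
    )
    trainer.fit(model, data_module)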
Real-world Example: Image Classification with Transfer Learning
Let's build a more practical example using transfer learning for image classification:
import os
import torch
import pytorch_lightning as pl
import torchvision.models as models
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
class TransferLearningModel(pl.LightningModule):
    def __init__(self, num_classes, learning_rate=0.001):
        super().__init__()
        self.save_hyperparameters()
        # Load a pre-trained ResNet model
        self.model = models.resnet18(pretrained=True)
        # Freeze the parameters
        for param in self.model.parameters():
            param.requires_grad = False
        # Replace the final fully connected layer
        num_features = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Log metrics
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('train_loss', loss)
        self.log('train_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Log metrics
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        # Only optimize the final fully connected layer
        optimizer = torch.optim.Adam(self.model.fc.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler
        }
class FlowerDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./flowers', batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Data transformations
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Load dataset
        dataset = ImageFolder(root=self.data_dir, transform=transform)
        # Split dataset
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        self.train_dataset, self.val_dataset = random_split(
            dataset, [train_size, val_size]
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=os.cpu_count()
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=os.cpu_count()
        )
# Assuming we have a flower dataset with 5 classes
model = TransferLearningModel(num_classes=5)
data_module = FlowerDataModule(data_dir='./flowers', batch_size=16)
# Create trainer with various callbacks
trainer = pl.Trainer(
    max_epochs=10,
    accelerator='auto',
    callbacks=[
        ModelCheckpoint(monitor='val_acc', mode='max'),
        EarlyStopping(monitor='val_loss', patience=3)
    ],
    logger=TensorBoardLogger("logs", name="flower_classifier")
)
# Train and validate
trainer.fit(model, data_module)
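Once training is done, you can run the classifier on a single image by applying the same preprocessing the DataModule uses. This is a rough sketch; the image path is just an example:

from PIL import Image

model.eval()
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

img = Image.open("some_flower.jpg").convert("RGB")   # placeholder path
x = transform(img).unsqueeze(0)                      # add a batch dimension

with torch.no_grad():
    pred = torch.argmax(model(x), dim=1)
print(pred.item())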
Common Pitfalls and Tips
- Accessing Hyperparameters: Use self.hparams to access hyperparameters saved with self.save_hyperparameters().
- Moving Tensors to GPU: Lightning handles device placement, so you don't need to call .to(device) yourself.
- Batch Processing: Always process the entire batch in the training, validation, and test steps.
- Loss Return Values: Always return the loss from training_step so Lightning can run backpropagation.
- Debugging: Use trainer = pl.Trainer(fast_dev_run=True) to run a single batch through training and validation as a quick sanity check.
- LR Finders: Lightning provides utilities like the learning rate finder:
# Find the optimal learning rate
trainer = pl.Trainer(auto_lr_find=True)
trainer.tune(model, data_module)
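Note that auto_lr_find and trainer.tune belong to the 1.x API; in Lightning 2.x the learning rate finder moved to the Tuner class. A rough equivalent, assuming your module exposes a learning-rate hyperparameter:

from pytorch_lightning.tuner import Tuner

trainer = pl.Trainer(max_epochs=5, accelerator='auto')
tuner = Tuner(trainer)

# Runs an LR range test; the result object exposes a suggested learning rate
lr_finder = tuner.lr_find(model, datamodule=data_module)
print(lr_finder.suggestion())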
Summary
PyTorch Lightning is an incredibly powerful framework that helps you write clean, organized, and reproducible deep learning code. By separating research code from engineering code, Lightning allows you to focus on the aspects of deep learning that matter most to your project.
Key takeaways:
- The LightningModule organizes your PyTorch code into specific methods
- Lightning handles training loops, validation, and hardware decisions for you
- You get built-in best practices like logging, checkpointing, and early stopping
- Your code remains flexible and can easily scale to multiple GPUs or TPUs
Exercises
- Convert a simple PyTorch model of your choice to Lightning format.
- Add image augmentation to the FlowerDataModule using torchvision transforms.
- Implement a GAN (Generative Adversarial Network) using Lightning's multi-optimizer support.
- Use Lightning's profiling capabilities to identify bottlenecks in your training pipeline.
- Create a Lightning model that uses mixed precision training to speed up your model training.
Happy Lightning!