PyTorch Model Evaluation
When you've trained a neural network in PyTorch, you need to know how well it performs. Model evaluation is a crucial step that helps you understand if your model is learning correctly and how it might perform on unseen data. In this guide, we'll explore various techniques to evaluate PyTorch models effectively.
Introduction to Model Evaluation
Model evaluation is the process of assessing how well your trained neural network performs on data it hasn't seen during training. This helps you:
- Measure the model's accuracy and other performance metrics
- Detect overfitting or underfitting
- Compare different models to choose the best one
- Make informed decisions about model improvements
Let's dive into how to evaluate PyTorch models step by step.
Setting Up the Evaluation Environment
First, let's import the necessary PyTorch libraries:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
Preparing Your Model for Evaluation
Before evaluation, you need to switch your model to evaluation mode with .eval(). This matters because it disables Dropout and makes BatchNorm layers use their running statistics instead of per-batch statistics:
# Assuming 'model' is your trained PyTorch model
model.eval()
# Disable gradient calculations for evaluation
with torch.no_grad():
    # Your evaluation code will go here
    pass
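To see why this matters, here's a tiny illustrative sketch (the layer and tensor are made up for this example): in training mode Dropout randomly zeroes activations and rescales the rest, while in evaluation mode it passes inputs through unchanged.

dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

dropout.train()     # training mode: about half the entries are zeroed,
print(dropout(x))   # and the survivors are scaled by 1/(1 - p) = 2

dropout.eval()      # evaluation mode: Dropout is an identity op
print(dropout(x))   # prints the unchanged tensor of ones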
Basic Evaluation Metrics
Accuracy
Accuracy is the most basic and intuitive metric - the percentage of correct predictions:
def calculate_accuracy(model, data_loader, device):
    model.eval()  # make sure Dropout/BatchNorm are in evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    accuracy = 100 * correct / total
    return accuracy
# Example usage
test_loader = DataLoader(test_dataset, batch_size=64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
accuracy = calculate_accuracy(model, test_loader, device)
print(f"Test Accuracy: {accuracy:.2f}%")
Output:
Test Accuracy: 92.45%
Loss
Computing the loss on your validation or test set gives you insight into how well your model is performing:
def calculate_loss(model, data_loader, criterion, device):
    model.eval()  # evaluation mode for consistent loss measurement
    total_loss = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
    average_loss = total_loss / total_samples
    return average_loss
# Example usage
criterion = nn.CrossEntropyLoss()
test_loss = calculate_loss(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}")
Output:
Test Loss: 0.2134
Advanced Evaluation Techniques
Confusion Matrix
A confusion matrix gives you a detailed breakdown of correct and incorrect predictions for each class:
def get_predictions(model, data_loader, device):
    model.eval()
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predictions = torch.max(outputs, 1)
            all_predictions.extend(predictions.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
    return np.array(all_predictions), np.array(all_targets)
# Get predictions
y_pred, y_true = get_predictions(model, test_loader, device)
# Generate confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plot confusion matrix
plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
classes = range(10) # For a 10-class problem like MNIST
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.show()
# Print classification report
print(classification_report(y_true, y_pred))
Precision, Recall, and F1-Score
For imbalanced datasets, accuracy alone can be misleading. Here's how to compute precision, recall, and F1-score:
from sklearn.metrics import precision_recall_fscore_support
# Calculate precision, recall, and F1 score
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average='weighted')
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
Output:
Precision: 0.9267
Recall: 0.9245
F1 Score: 0.9251
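A note on averaging: average='weighted' weights each class by its support, so a dominant class can hide poor performance on rare ones. For imbalanced data it is often worth also checking macro and per-class scores; here is a small sketch reusing y_true and y_pred from above:

# Macro averaging: every class counts equally, so neglected minority
# classes pull the score down
macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average='macro')
print(f"Macro F1: {macro_f1:.4f}")

# average=None returns one value per class
per_class_p, per_class_r, per_class_f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average=None)
print("Per-class F1:", np.round(per_class_f1, 4))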
Evaluating with a Validation Set
During training, it's best practice to use a validation set to monitor performance:
def train_with_validation(model, train_loader, val_loader, criterion, optimizer,
                          device, epochs=10):
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            # Zero the parameter gradients
            optimizer.zero_grad()
            # Forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            # Track statistics
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = 100 * correct / total
        train_losses.append(epoch_loss)
        train_accs.append(epoch_acc)

        # Validation phase
        val_loss = calculate_loss(model, val_loader, criterion, device)
        val_acc = calculate_accuracy(model, val_loader, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print("-" * 50)

    # Plot training and validation metrics
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss Curves')
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(val_accs, label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy %')
    plt.legend()
    plt.title('Accuracy Curves')
    plt.tight_layout()
    plt.show()
    return train_losses, val_losses, train_accs, val_accs
Practical Example: Evaluating an MNIST Classifier
Let's put everything together with a complete example using the MNIST dataset:
# Create model, datasets, and loaders
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
# Prepare the data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# Load and split datasets
full_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)
# Initialize model, loss, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Train model with validation
train_losses, val_losses, train_accs, val_accs = train_with_validation(
    model, train_loader, val_loader, criterion, optimizer, device, epochs=5)
# Final evaluation on test set
test_loss = calculate_loss(model, test_loader, criterion, device)
test_acc = calculate_accuracy(model, test_loader, device)
print(f"Final Test Loss: {test_loss:.4f}")
print(f"Final Test Accuracy: {test_acc:.2f}%")
# Get confusion matrix and classification report
y_pred, y_true = get_predictions(model, test_loader, device)
print("\nClassification Report:")
print(classification_report(y_true, y_pred))
Using Evaluation to Detect Overfitting
One of the most important uses of evaluation is to detect overfitting, where your model performs well on training data but poorly on new data:
def plot_overfitting_analysis(train_losses, val_losses, train_accs, val_accs):
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.axvspan(
        np.argmin(val_losses), len(val_losses)-1,
        alpha=0.3, color='red', label='Potential Overfitting'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss Curves (Potential Overfitting Detection)')
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(val_accs, label='Validation Accuracy')
    plt.axvspan(
        np.argmax(val_accs), len(val_accs)-1,
        alpha=0.3, color='red', label='Potential Overfitting'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy %')
    plt.legend()
    plt.title('Accuracy Curves (Potential Overfitting Detection)')
    plt.tight_layout()
    plt.show()

    best_epoch = np.argmin(val_losses)
    print(f"Best model found at epoch {best_epoch + 1}")
    print(f"Best validation loss: {val_losses[best_epoch]:.4f}")
    print(f"Validation accuracy at best epoch: {val_accs[best_epoch]:.2f}%")
Saving and Loading the Best Model
After evaluation, you'll want to save your best model:
def save_best_model(model, val_losses, model_path='best_model.pth'):
    # Note: this saves the model's current (final) weights;
    # best_epoch is only reported, not restored
    best_epoch = np.argmin(val_losses)
    torch.save(model.state_dict(), model_path)
    print(f"Best model saved at epoch {best_epoch + 1} to {model_path}")
# Save the model
save_best_model(model, val_losses, 'mnist_classifier.pth')
# Load the model later
loaded_model = SimpleNN().to(device)
loaded_model.load_state_dict(torch.load('mnist_classifier.pth', map_location=device))
loaded_model.eval() # Set to evaluation mode
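Note that save_best_model above writes out whatever weights the model currently holds. To actually keep the best-performing weights, you can checkpoint during training whenever the validation loss improves. Here is a minimal sketch of that idea; train_with_checkpointing and checkpoint_path are illustrative names, not part of the earlier code:

def train_with_checkpointing(model, train_loader, val_loader, criterion,
                             optimizer, device, epochs=10,
                             checkpoint_path='best_model.pth'):
    # Track the lowest validation loss seen so far
    best_val_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
        # Evaluate on the validation set after each epoch
        val_loss = calculate_loss(model, val_loader, criterion, device)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Save a checkpoint only when validation loss improves
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Epoch {epoch+1}: new best val loss {val_loss:.4f}, checkpoint saved")
    return best_val_loss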
Cross-Validation in PyTorch
For more robust evaluation, especially with smaller datasets, you might use k-fold cross-validation:
from sklearn.model_selection import KFold

def cross_validate(dataset, model_class, criterion, k_folds=5,
                   batch_size=64, epochs=3, learning_rate=0.001):
    kfold = KFold(n_splits=k_folds, shuffle=True)
    fold_results = []
    for fold, (train_ids, val_ids) in enumerate(kfold.split(dataset)):
        print(f"FOLD {fold+1}/{k_folds}")
        print('-' * 30)
        # Sample elements randomly from the given lists of indices
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)
        train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_subsampler)
        val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_subsampler)
        # Initialize a fresh model and optimizer for each fold
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model_class().to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        # Train for the specified number of epochs
        for epoch in range(epochs):
            model.train()
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
        # Evaluate on the validation set for this fold
        val_acc = calculate_accuracy(model, val_loader, device)
        print(f"Validation Accuracy: {val_acc:.2f}%")
        fold_results.append(val_acc)

    # Print average results
    print(f"\nK-Fold Cross Validation Results for {k_folds} Folds:")
    print(f"Mean Accuracy: {np.mean(fold_results):.2f}%")
    print(f"Standard Deviation: {np.std(fold_results):.2f}%")
    return fold_results
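For example, you could reuse the SimpleNN class, criterion, and full_dataset from the MNIST example above (running all folds over the full 60,000-image training set takes a while, especially on CPU):

fold_results = cross_validate(full_dataset, SimpleNN, criterion,
                              k_folds=5, epochs=3)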
Evaluating for Different Tasks
Regression Evaluation
If you're working on a regression task rather than classification:
def evaluate_regression_model(model, test_loader, device):
    model.eval()
    predictions = []
    actual_values = []
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            predictions.extend(outputs.cpu().numpy())
            actual_values.extend(targets.cpu().numpy())
    # Flatten so that (N, 1) outputs and (N,) targets line up correctly
    predictions = np.array(predictions).flatten()
    actual_values = np.array(actual_values).flatten()
    # Calculate MSE
    mse = np.mean((predictions - actual_values) ** 2)
    # Calculate RMSE
    rmse = np.sqrt(mse)
    # Calculate MAE
    mae = np.mean(np.abs(predictions - actual_values))
    # Calculate R²
    ss_total = np.sum((actual_values - np.mean(actual_values)) ** 2)
    ss_residual = np.sum((actual_values - predictions) ** 2)
    r_squared = 1 - (ss_residual / ss_total)
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"Root Mean Squared Error (RMSE): {rmse:.4f}")
    print(f"Mean Absolute Error (MAE): {mae:.4f}")
    print(f"R² Score: {r_squared:.4f}")
    return mse, rmse, mae, r_squared
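As a usage sketch, the snippet below builds a tiny synthetic regression problem just to show the call; the model is an untrained placeholder, so the reported metrics are not meaningful:

from torch.utils.data import TensorDataset

# Stand-in model and data, made up purely for illustration
reg_model = nn.Linear(3, 1).to(device)
X = torch.randn(256, 3)
y = X.sum(dim=1, keepdim=True) + 0.1 * torch.randn(256, 1)
reg_test_loader = DataLoader(TensorDataset(X, y), batch_size=32)

mse, rmse, mae, r2 = evaluate_regression_model(reg_model, reg_test_loader, device)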
Summary
Model evaluation is a critical step in the machine learning workflow. In this guide, we've covered:
- Basic evaluation with accuracy and loss
- Advanced metrics like confusion matrices, precision, and recall
- Validation techniques to prevent overfitting
- Saving and loading the best model
- Cross-validation for more robust evaluation
- Specialized metrics for different tasks
Remember that proper evaluation helps you:
- Understand how well your model is performing
- Identify areas for improvement
- Make informed decisions about model selection
- Gain confidence in your model before deployment
Exercises
- Exercise 1: Train a simple CNN on CIFAR-10 and evaluate it using accuracy, precision, recall, and F1-score.
- Exercise 2: Apply k-fold cross-validation to a custom dataset and plot the performance across different folds.
- Exercise 3: Create a function that performs early stopping during training based on validation loss.
- Exercise 4: Build a regression model for a time series dataset and evaluate it using appropriate metrics.
- Exercise 5: Compare two different model architectures using the evaluation techniques learned in this guide.