PyTorch Image Classification

Introduction

Image classification is one of the fundamental tasks in computer vision: we train a model to categorize images into predefined classes. For example, a model might determine whether an image contains a cat, a dog, or neither.

PyTorch provides powerful tools and libraries to build sophisticated image classification systems with just a few lines of code. In this tutorial, we'll walk through the entire process of creating an image classification model using PyTorch - from preparing the data to making predictions on new images.

Prerequisites

Before starting, make sure you have:

  • Basic understanding of Python
  • Familiarity with neural network concepts
  • PyTorch installed (pip install torch torchvision)
  • A GPU (recommended but not required)

The Image Classification Pipeline

Let's break down the image classification process into manageable steps:

  1. Loading and preparing the dataset
  2. Building the neural network architecture
  3. Training the model
  4. Evaluating model performance
  5. Making predictions with the trained model

1. Loading and Preparing the Dataset

For this tutorial, we'll use the CIFAR-10 dataset, which contains 60,000 32x32 color images across 10 classes, split into 50,000 training images and 10,000 test images.

python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Scale each channel to [-1, 1]
])

# Load training dataset
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2
)

# Load test dataset
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
    num_workers=2
)

# Classes in CIFAR-10
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Visualizing the Data

Let's visualize some of the training images to understand our dataset better:

python
# Function to show an image
def imshow(img):
    img = img / 2 + 0.5  # Unnormalize from [-1, 1] back to [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C) for matplotlib
    plt.axis('off')
    plt.show()

# Get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Show images
imshow(torchvision.utils.make_grid(images[:5]))
# Print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(5)))

Output:

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
plane car bird cat deer

The output will display 5 images from the dataset along with their class labels.

2. Building the Network

Now, let's define our convolutional neural network (CNN) architecture:

python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # First convolutional layer
        # Input: 3 channels (RGB), Output: 6 feature maps, 5x5 kernel
        self.conv1 = nn.Conv2d(3, 6, 5)
        # Max pooling layer with 2x2 window
        self.pool = nn.MaxPool2d(2, 2)
        # Second convolutional layer
        # Input: 6 channels, Output: 16 feature maps, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Fully connected layers
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # 10 output classes for CIFAR-10

    def forward(self, x):
        # Convolution -> ReLU -> Max pooling
        x = self.pool(F.relu(self.conv1(x)))
        # Convolution -> ReLU -> Max pooling
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the tensor for the fully connected layers
        x = x.view(-1, 16 * 5 * 5)
        # Fully connected layers with ReLU activations
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Final layer (no activation - softmax is applied inside the loss function)
        x = self.fc3(x)
        return x

# Create an instance of the network and move it to the device (CPU/GPU)
net = Net().to(device)
print(net)

Output:

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

Understanding the Architecture

Our CNN architecture consists of:

  1. Two convolutional layers with ReLU activation and max pooling
  2. Three fully connected layers
  3. The output layer produces 10 values, one for each class in CIFAR-10
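
To see where the 16 * 5 * 5 = 400 input size of fc1 comes from, it helps to trace a dummy batch through the layers. A quick sanity check (not part of the pipeline itself):

python
# Trace tensor shapes through the network with a dummy single-image batch
x = torch.randn(1, 3, 32, 32).to(device)  # CIFAR-10 input: 3 channels, 32x32
x = net.pool(F.relu(net.conv1(x)))
print(x.shape)  # torch.Size([1, 6, 14, 14])  (32 - 5 + 1 = 28, pooled to 14)
x = net.pool(F.relu(net.conv2(x)))
print(x.shape)  # torch.Size([1, 16, 5, 5])   (14 - 5 + 1 = 10, pooled to 5)
print(x.view(-1, 16 * 5 * 5).shape)  # torch.Size([1, 400]) = fc1's in_features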

3. Training the Model

Let's define a loss function and an optimizer, then train our network:

python
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Training loop
def train_model(epochs=5):
    print("Starting training...")
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # Get the inputs and move them to the device
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = net(inputs)
            # Calculate loss
            loss = criterion(outputs, labels)
            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Print statistics
            running_loss += loss.item()
            if i % 200 == 199:  # Print every 200 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}')
                running_loss = 0.0

    print('Finished Training')

    # Save the model
    torch.save(net.state_dict(), 'cifar_net.pth')

# Train the model
train_model(epochs=5)

Output:

Starting training...
[1, 200] loss: 2.169
[1, 400] loss: 1.826
[1, 600] loss: 1.669
[1, 800] loss: 1.572
[2, 200] loss: 1.468
[2, 400] loss: 1.417
[2, 600] loss: 1.380
[2, 800] loss: 1.348
[3, 200] loss: 1.285
[3, 400] loss: 1.253
[3, 600] loss: 1.229
[3, 800] loss: 1.210
[4, 200] loss: 1.163
[4, 400] loss: 1.148
[4, 600] loss: 1.127
[4, 800] loss: 1.117
[5, 200] loss: 1.079
[5, 400] loss: 1.066
[5, 600] loss: 1.062
[5, 800] loss: 1.051
Finished Training

4. Evaluating the Model

After training, we need to check how well our model performs on unseen data:

python
def evaluate_model():
    # Load the saved model (map_location keeps this working on CPU-only machines)
    net.load_state_dict(torch.load('cifar_net.pth', map_location=device))
    net.eval()  # Set to evaluation mode

    # Prepare to count predictions
    correct = 0
    total = 0

    # Since we're not training, we don't need to calculate gradients
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            # Calculate outputs by running images through the network
            outputs = net(images)

            # The class with the highest score is the prediction
            _, predicted = torch.max(outputs, 1)

            # Count the total number of labels and correct predictions
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy on 10,000 test images: {100 * correct / total:.2f}%')

    # Performance analysis per class
    class_correct = [0.0] * 10
    class_total = [0.0] * 10

    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    for i in range(10):
        print(f'Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')

# Evaluate the model
evaluate_model()

Output:

Accuracy on 10,000 test images: 62.47%
Accuracy of plane: 65.70%
Accuracy of car: 73.70%
Accuracy of bird: 50.20%
Accuracy of cat: 44.50%
Accuracy of deer: 54.50%
Accuracy of dog: 51.10%
Accuracy of frog: 72.90%
Accuracy of horse: 67.00%
Accuracy of ship: 77.90%
Accuracy of truck: 69.50%
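
The per-class breakdown shows which categories the model finds hardest (cats and birds here). For a finer-grained view, a confusion matrix records which classes get mistaken for which. A minimal sketch, reusing the net and testloader defined above:

python
# Build a 10x10 confusion matrix: rows are true classes, columns are predictions
confusion = torch.zeros(10, 10, dtype=torch.long)

net.eval()
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        _, predicted = torch.max(net(images), 1)
        for t, p in zip(labels, predicted):
            confusion[t.item(), p.item()] += 1

# For example, inspect what 'cat' images are most often classified as
print(dict(zip(classes, confusion[classes.index('cat')].tolist())))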

5. Making Predictions

Now let's use our trained model to make predictions on new images:

python
from PIL import Image

def predict_image(image_path):
    # Transformation for a single image: resize to the 32x32 size CIFAR-10 uses
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load and preprocess; convert('RGB') handles grayscale or RGBA inputs
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # Add batch dimension
    image = image.to(device)

    # Predict
    net.eval()  # Set to evaluation mode
    with torch.no_grad():
        output = net(image)
        _, predicted = torch.max(output, 1)

    # Return the predicted class
    return classes[predicted.item()]

# Example usage (assuming you have an image file)
# prediction = predict_image('path_to_your_image.jpg')
# print(f'Predicted class: {prediction}')
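
The raw outputs are unnormalized scores (logits). If you also want a confidence estimate alongside the label, you can apply softmax to turn them into probabilities. A small sketch of how the end of predict_image could be extended:

python
# Inside predict_image, after computing output:
probs = F.softmax(output, dim=1).squeeze(0)  # Probabilities over the 10 classes
confidence, predicted = torch.max(probs, 0)  # Top class and its probability
print(f'{classes[predicted.item()]} ({confidence.item():.1%})')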

Real-World Applications

Image classification has numerous practical applications:

  1. Medical Diagnosis: Identifying diseases from medical images like X-rays or MRIs
  2. Autonomous Vehicles: Recognizing road signs, pedestrians, and other vehicles
  3. Security Systems: Face recognition and intrusion detection
  4. Agriculture: Identifying plant diseases or crop quality
  5. Retail: Visual product search and automated inventory management

Let's see a practical example of how our model could be extended for a real-world scenario:

python
class MedicalImageClassifier(nn.Module):
    def __init__(self, num_classes=2):  # For example, normal vs. abnormal
        super(MedicalImageClassifier, self).__init__()
        # Using ResNet as base architecture (transfer learning)
        # Note: the weights argument replaces the deprecated pretrained=True
        self.resnet = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.DEFAULT)

        # Freeze early layers to keep learned features
        for param in list(self.resnet.parameters())[:-20]:
            param.requires_grad = False

        # Replace final layer to match our number of classes
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

# Example of how we'd use this for medical image classification
# medical_model = MedicalImageClassifier()
# medical_model = medical_model.to(device)
# Then we would train it similarly to our CIFAR-10 model but with medical images
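
When training a partially frozen model like this, a common pattern is to hand the optimizer only the parameters that still require gradients. A minimal sketch of that setup (the loop itself would mirror train_model above; names here are illustrative):

python
# Hypothetical fine-tuning setup for the partially frozen model
medical_model = MedicalImageClassifier().to(device)

# Optimize only the unfrozen parameters (those with requires_grad=True)
trainable_params = [p for p in medical_model.parameters() if p.requires_grad]
ft_optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)
ft_criterion = nn.CrossEntropyLoss()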

Improving the Model

There are several ways to improve our image classification model:

  1. Data Augmentation: Apply random transformations to training images
  2. Transfer Learning: Use pre-trained models like ResNet or VGG
  3. Hyperparameter Tuning: Find optimal learning rate, batch size, etc.
  4. Deeper or Different Architectures: Try more advanced CNN designs
  5. Regularization Techniques: Add dropout or batch normalization (see the sketch after the augmentation example below)

Here's how to implement data augmentation:

python
# Enhanced transformations with data augmentation
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # Random 32x32 crop after padding
    transforms.RandomHorizontalFlip(),     # Flip left-right with probability 0.5
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# We would then use this transform for the training data
augmented_trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
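
And here is one way regularization (item 5 in the list above) could be folded into our architecture: a sketch, not the tutorial's original network, that inserts batch normalization after each convolution and dropout before the final layer:

python
class RegularizedNet(nn.Module):
    def __init__(self):
        super(RegularizedNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)      # Normalize conv1's feature maps
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(16)     # Normalize conv2's feature maps
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.dropout = nn.Dropout(p=0.5)  # Randomly zero half the units during training
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x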

Summary

In this tutorial, we've built a complete image classification system using PyTorch, covering:

  1. Data Preparation: Loading and preprocessing CIFAR-10 dataset
  2. Model Architecture: Creating a CNN with convolutional and fully connected layers
  3. Training Process: Using backpropagation and gradient descent to train the model
  4. Evaluation: Testing the model on unseen data and analyzing performance
  5. Prediction: Using the trained model to classify new images
  6. Real-World Applications: Extending the model for practical use cases

Image classification is a fundamental computer vision task and serves as a building block for more complex applications. By mastering these techniques in PyTorch, you'll be well-equipped to tackle a variety of computer vision challenges.

Exercises

  1. Experiment with Data Augmentation: Implement different data augmentation techniques and observe how they affect model performance.

  2. Transfer Learning Challenge: Modify the code to use a pre-trained model like ResNet18 and fine-tune it for the CIFAR-10 dataset.

  3. Build Your Own Dataset: Create an image classifier for a custom dataset of your interest.

  4. Hyperparameter Tuning: Experiment with different learning rates, optimizers, and batch sizes to improve performance.

  5. Visualization: Implement code to visualize the feature maps learned by your CNN to better understand what patterns it's detecting.

By completing these exercises, you'll gain a deeper understanding of image classification with PyTorch and be ready to apply these techniques to your own projects!


