PyTorch TorchVision

Introduction to TorchVision

TorchVision is the PyTorch ecosystem's package for computer vision. It provides image and video transformations, loaders for popular datasets, and pre-trained models, so you can start on a vision task without rebuilding these pieces yourself.

This library consists of three main components:

  • Datasets: Easy access to common benchmark datasets
  • Models: Pre-trained neural network architectures
  • Transformations: Common image preprocessing and augmentation techniques

Setting Up TorchVision

Before we dive into using TorchVision, let's make sure it's properly installed.

python
# Install TorchVision if you don't have it yet
# !pip install torchvision

# Import essential libraries
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np

# Print the versions to ensure compatibility
print(f"PyTorch version: {torch.__version__}")
print(f"TorchVision version: {torchvision.__version__}")

Output:

PyTorch version: 2.0.1
TorchVision version: 0.15.2

TorchVision Datasets

One of the most valuable aspects of TorchVision is its easy access to popular datasets. Let's explore how to load and use the MNIST dataset, a collection of handwritten digits.

python
from torchvision import datasets, transforms

# Define a transformation pipeline
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize with mean and std of MNIST
])

# Download and load the training dataset
train_dataset = datasets.MNIST(root='./data',
                               train=True,
                               download=True,
                               transform=transform)

# Download and load the testing dataset
test_dataset = datasets.MNIST(root='./data',
                              train=False,
                              download=True,
                              transform=transform)

# Create data loaders for batch processing
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Let's see what we have
print(f"Training dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Output:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
Training dataset size: 60000
Test dataset size: 10000

Visualizing Dataset Images

Let's visualize some of the images from the MNIST dataset:

python
# Function to display images from our dataset
def show_images(loader, n=25):
    # Get one batch (random when the loader shuffles)
    images, labels = next(iter(loader))

    # Create a grid from the first n images. normalize=True rescales the
    # standardized pixel values back into [0, 1] so they display correctly.
    img_grid = torchvision.utils.make_grid(images[:n], nrow=5, normalize=True)

    # Show images (make_grid returns C x H x W; imshow expects H x W x C)
    plt.figure(figsize=(10, 10))
    plt.imshow(img_grid.permute(1, 2, 0).numpy())
    plt.axis('off')
    plt.title('Sample Images from Dataset')
    plt.show()

    # Print labels
    print('Labels:', ' '.join(str(label) for label in labels[:n].tolist()))

# Show some training images
show_images(train_loader)

TorchVision Transforms

Transforms are a crucial part of image preprocessing. TorchVision offers various transforms that can be applied to images for normalization, augmentation, and other preprocessing needs.

Let's explore some common transforms:

python
from io import BytesIO

import requests
from PIL import Image
from torchvision import transforms

# Download a sample image
response = requests.get('https://upload.wikimedia.org/wikipedia/commons/e/e3/Orangutan_01.jpg')
response.raise_for_status()  # Fail early if the download did not succeed
img = Image.open(BytesIO(response.content)).convert('RGB')

# Define various transformations
transform_operations = {
    "Original": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ]),
    "Horizontal Flip": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=1.0),  # p=1.0 always flips
        transforms.ToTensor()
    ]),
    "Color Jitter": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        transforms.ToTensor()
    ]),
    "Rotation": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomRotation(30),  # Random angle between -30 and +30 degrees
        transforms.ToTensor()
    ]),
    "Crop": transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ]),
}

# Apply transformations and display results
fig = plt.figure(figsize=(15, 4))

for i, (name, operation) in enumerate(transform_operations.items(), 1):
    transformed_img = operation(img)
    ax = fig.add_subplot(1, 5, i)
    ax.set_title(name)
    ax.imshow(transformed_img.permute(1, 2, 0))  # C x H x W -> H x W x C
    ax.axis('off')

plt.tight_layout()
plt.show()

TorchVision Pre-trained Models

TorchVision provides access to popular pre-trained models that can be used for various tasks like image classification, object detection, and semantic segmentation.

Using a Pre-trained Model for Image Classification

Let's use a pre-trained ResNet model to classify our sample image:

python
import urllib.request

import torch.nn.functional as F
from torchvision import models

# Load a pre-trained ResNet model
model = models.resnet18(weights='IMAGENET1K_V1')
model.eval()  # Set the model to evaluation mode

# Define the preprocessing pipeline for ResNet
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Apply preprocessing to our sample image
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0)  # Add batch dimension

# Make a prediction (no gradients are needed for inference)
with torch.no_grad():
    output = model(input_batch)

# The output has unnormalized scores. To get probabilities, run a softmax on it.
probabilities = F.softmax(output[0], dim=0)

# Download the ImageNet class labels
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
urllib.request.urlretrieve(url, "imagenet_classes.txt")

with open("imagenet_classes.txt") as f:
    categories = [s.strip() for s in f.readlines()]

# Show top 5 predictions
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(f"{categories[top5_catid[i]]} ({top5_prob[i].item()*100:.2f}%)")

Output:

orangutan, orang, orangutang, Pongo pygmaeus (97.47%)
chimpanzee, chimp, Pan troglodytes (2.51%)
gibbon, Hylobates lar (0.01%)
gorilla, Gorilla gorilla (0.01%)
langur (0.00%)

Real-world Application: Transfer Learning

One of the most practical applications of TorchVision models is transfer learning: reusing a pre-trained model on a new task. Let's see how to adapt a pre-trained model to our own image classification task by freezing its pre-trained layers and training only a new final layer:

python
import torch.nn as nn
import torch.optim as optim

# Load a pre-trained model
model = models.resnet18(weights='IMAGENET1K_V1')

# Freeze every pre-trained layer; the new final layer added below stays trainable
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer
num_ftrs = model.fc.in_features
num_classes = 10  # Example: 10 different classes in our dataset
model.fc = nn.Linear(num_ftrs, num_classes)

# Define loss function and optimizer (only the new layer's parameters are optimized)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

# Function to train the model
def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
    model.train()  # Make sure the model is in training mode
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # Print every 100 mini-batches
                print(f'[Epoch {epoch + 1}, Batch {i + 1}] Loss: {running_loss / 100:.3f}')
                running_loss = 0.0

    print('Finished Training')

# Note: In a real-world scenario, you would now call train_model with your actual dataset
# (3-channel images; the single-channel MNIST loader above would not fit this model as-is)
# train_model(model, train_loader, criterion, optimizer, num_epochs=5)

Object Detection with TorchVision

TorchVision also provides pre-trained models for object detection tasks. Here's how you can use a pre-trained Faster R-CNN model:

python
from torchvision.models.detection import fasterrcnn_resnet50_fpn
# Imported as TF so it does not shadow torch.nn.functional, imported as F earlier
from torchvision.transforms import functional as TF

# Load a pre-trained Faster R-CNN model
model = fasterrcnn_resnet50_fpn(weights='DEFAULT')
model.eval()

# Define COCO dataset class names
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# Function to make prediction and visualize
def detect_objects(image_path_or_url):
    # If it's a URL, download the image
    if image_path_or_url.startswith('http'):
        response = requests.get(image_path_or_url)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        img = Image.open(image_path_or_url).convert('RGB')

    # Convert PIL image to tensor
    img_tensor = TF.to_tensor(img)

    # Make prediction (the model takes a list of image tensors)
    with torch.no_grad():
        prediction = model([img_tensor])

    # Get results
    boxes = prediction[0]['boxes'].detach().numpy()
    scores = prediction[0]['scores'].detach().numpy()
    labels = prediction[0]['labels'].detach().numpy()

    # Filter results (keep predictions with score > 0.7)
    mask = scores > 0.7
    boxes = boxes[mask]
    labels = labels[mask]
    scores = scores[mask]

    # Visualize result
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img)

    for box, label, score in zip(boxes, labels, scores):
        x1, y1, x2, y2 = box
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color='red', linewidth=2)
        ax.add_patch(rect)
        ax.text(x1, y1 - 10, f"{COCO_INSTANCE_CATEGORY_NAMES[label]}: {score:.2f}",
                color='white', fontsize=12, backgroundcolor='red')

    ax.axis('off')
    plt.tight_layout()
    plt.show()

# Example use:
# detect_objects('https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg')

Custom DataLoaders with TorchVision

For real-world projects, you often need to work with your own images. A custom Dataset class wraps a folder of images so that a standard DataLoader can batch them, with TorchVision transforms applied to each image:

python
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        """
        Custom dataset for image classification
        Args:
            img_dir (str): Directory with all images organized in class folders
            transform (callable, optional): Optional transform to be applied on an image
        """
        self.img_dir = img_dir
        self.transform = transform
        # Sorted so that class indices are the same on every run and platform
        self.classes = sorted(d for d in os.listdir(img_dir)
                              if os.path.isdir(os.path.join(img_dir, d)))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        self.images = []
        self.labels = []

        # Collect all image paths and their corresponding labels
        for cls_name in self.classes:
            class_path = os.path.join(img_dir, cls_name)
            class_idx = self.class_to_idx[cls_name]

            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if os.path.isfile(img_path):
                    self.images.append(img_path)
                    self.labels.append(class_idx)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Example usage:
# train_transform = transforms.Compose([
#     transforms.RandomResizedCrop(224),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])
#
# dataset = CustomImageDataset(img_dir='./dataset/train', transform=train_transform)
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

Summary

In this tutorial, we explored TorchVision, an essential library for computer vision tasks in PyTorch. We covered:

  1. Datasets: How to load built-in datasets like MNIST and create custom datasets
  2. Transforms: Various image preprocessing and augmentation techniques
  3. Models: Using pre-trained models for image classification and object detection
  4. Transfer Learning: Adapting pre-trained models for new tasks

TorchVision simplifies many complex tasks in computer vision and provides a solid foundation for your deep learning projects. Whether you're working on image classification, object detection, segmentation, or other vision tasks, TorchVision offers tools that can accelerate your development process.

Exercises

  1. Basic Exercise: Load the CIFAR-10 dataset using TorchVision and display a batch of images with their labels.
  2. Intermediate Exercise: Implement a transfer learning approach to classify images from the Flowers dataset using a pre-trained ResNet model.
  3. Advanced Exercise: Create a simple object detection application that uses a webcam to detect objects in real-time using a pre-trained Faster R-CNN model from TorchVision.
  4. Challenge: Create a custom dataset loader for your own image dataset and train a model to classify these images.

By mastering TorchVision, you'll have a powerful set of tools for working with images and developing state-of-the-art computer vision applications!


