
PyTorch Image Processing

Introduction

Image processing is a fundamental aspect of computer vision applications. PyTorch provides powerful tools for loading, manipulating, and transforming images to prepare them for deep learning models. This tutorial will guide you through the essential image processing techniques in PyTorch, from basic operations to more advanced transformations.

Whether you're building an image classifier, object detector, or any other vision model, understanding how to properly process your image data is crucial for achieving good results.

Getting Started with PyTorch Image Processing

Required Libraries

Before we begin, let's make sure we have all the necessary libraries installed:

bash
pip install torch torchvision pillow matplotlib

Let's import the libraries we'll need:

python
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

Loading Images in PyTorch

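Core PyTorch has no image-file loader of its own. This tutorial uses PIL (the Pillow library) to open image files and torchvision to convert them to tensors. torchvision also offers torchvision.io.read_image, which decodes a file straight into a uint8 tensor. Let's see how to load an image and convert it to a PyTorch tensor: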

python
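# Load an image using PIL; convert('RGB') guarantees 3 channels
# (it drops an alpha channel and expands grayscale images)
img = Image.open('sample_image.jpg').convert('RGB')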

# Convert PIL image to PyTorch tensor
transform = transforms.ToTensor()
img_tensor = transform(img)

print(f"Image tensor shape: {img_tensor.shape}")
print(f"Tensor data type: {img_tensor.dtype}")

Output:

Image tensor shape: torch.Size([3, 480, 640])
Tensor data type: torch.float32

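The tensor shape is [C, H, W] (channels, height, width), the standard layout in PyTorch. Color images typically have 3 channels (RGB). PIL's img.size reports (width, height), which is the reverse order, so a 640×480 image becomes a [3, 480, 640] tensor. ToTensor() also converts 8-bit pixel values from the range [0, 255] to floats in [0.0, 1.0].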

Basic Image Transformations

PyTorch's torchvision.transforms module provides numerous tools for image preprocessing. Let's explore some common transformations:

Resizing Images

python
# Resize to exactly 224x224; the tuple is (height, width), so the aspect ratio may change
resize_transform = transforms.Resize((224, 224))
resized_img = resize_transform(img)

# Convert to tensor for visualization
resized_tensor = transforms.ToTensor()(resized_img)

print(f"Original image size: {img.size}")
print(f"Resized image size: {resized_img.size}")

Output:

Original image size: (640, 480)
Resized image size: (224, 224)

Cropping Images

python
# Center crop the image
center_crop = transforms.CenterCrop((150, 150))
cropped_img = center_crop(img)

# Random crop
random_crop = transforms.RandomCrop((100, 100))
random_cropped_img = random_crop(img)

print(f"Center cropped size: {cropped_img.size}")
print(f"Random cropped size: {random_cropped_img.size}")

Output:

Center cropped size: (150, 150)
Random cropped size: (100, 100)

Color and Intensity Transformations

python
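# Convert to grayscale (1 output channel by default;
# use num_output_channels=3 to feed models that expect RGB input)
grayscale = transforms.Grayscale()
gray_img = grayscale(img)

# Randomly adjust brightness, contrast, saturation and hue.
# hue must lie in [0, 0.5]; 0.5 is the maximum shift and can change colors drastically
color_jitter = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
jittered_img = color_jitter(img)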

Creating a Transformation Pipeline

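One of the most useful features of torchvision.transforms is the ability to chain multiple transformations with transforms.Compose. The transforms run in the order listed, so ToTensor() must come before Normalize(), which operates only on tensors: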

python
# Create a transformation pipeline
transform_pipeline = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Apply the pipeline to the image
transformed_img_tensor = transform_pipeline(img)

print(f"Transformed tensor shape: {transformed_img_tensor.shape}")
print(f"Value range after normalization: [{transformed_img_tensor.min():.2f}, {transformed_img_tensor.max():.2f}]")

Output:

Transformed tensor shape: torch.Size([3, 224, 224])
Value range after normalization: [-2.12, 2.64]

Visualizing Transformed Images

Let's create a helper function to visualize our transformed images:

python
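def show_image(tensor, title=None, mean=None, std=None):
    """
    Displays a [C, H, W] PyTorch tensor as an image.
    Pass the mean/std used by transforms.Normalize to undo normalization.
    """
    # Detach from the autograd graph and move to CPU so .numpy() works
    tensor = tensor.detach().cpu()

    # Undo normalization per channel: x = x_norm * std + mean
    if mean is not None and std is not None:
        mean = torch.tensor(mean).view(-1, 1, 1)
        std = torch.tensor(std).view(-1, 1, 1)
        tensor = tensor * std + mean

    # Move the channel dimension to the end for matplotlib: [H, W, C]
    if tensor.dim() == 3 and tensor.shape[0] in (1, 3):
        tensor = tensor.permute(1, 2, 0)

    # For grayscale images, drop the singleton channel: [H, W]
    if tensor.shape[-1] == 1:
        tensor = tensor.squeeze(-1)

    # Clip to the [0, 1] range that imshow expects for float images
    img_np = np.clip(tensor.numpy(), 0, 1)

    plt.figure(figsize=(10, 10))
    plt.imshow(img_np, cmap='gray' if img_np.ndim == 2 else None)
    if title:
        plt.title(title)
    plt.axis('off')
    plt.show()

# Example usage:
# Load and transform an image
img = Image.open('sample_image.jpg').convert('RGB')
tensor = transforms.ToTensor()(img)
show_image(tensor, "Original Image")

# Apply a transformation and show
flipped = transforms.RandomHorizontalFlip(p=1.0)(img)  # p=1.0 ensures flipping
flipped_tensor = transforms.ToTensor()(flipped)
show_image(flipped_tensor, "Horizontally Flipped Image")

# Undo Normalize before displaying the output of transform_pipeline
show_image(transformed_img_tensor, "Pipeline Output",
           mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])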

Data Augmentation for Training

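Data augmentation artificially expands your training dataset by creating modified versions of the original images. The random transforms are re-sampled every time an image is loaded, so the model sees a slightly different version of each image in every epoch. This helps it generalize instead of memorizing the training set. Augmentation belongs in the training pipeline only. The validation and test pipelines should be deterministic so that evaluation results are comparable across runs.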

python
# Define an augmentation pipeline for training
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Simpler pipeline for validation/testing
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

Loading Images with DataLoader

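PyTorch's Dataset and DataLoader classes make it easy to load and process batches of images. torchvision's ImageFolder is a ready-made Dataset for images stored in one subdirectory per class, for example ./data/train/cat/001.jpg and ./data/train/dog/002.jpg. The subdirectory names become the class labels: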

python
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Create datasets
train_dataset = ImageFolder(root='./data/train', transform=train_transforms)
val_dataset = ImageFolder(root='./data/val', transform=val_transforms)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Example of iterating through batches
for images, labels in train_loader:
    print(f"Batch shape: {images.shape}")
    print(f"Labels: {labels}")
    break  # Just display the first batch

Output:

Batch shape: torch.Size([32, 3, 224, 224])
Labels: tensor([0, 2, 1, 0, 3, 1, 2, 0, ...])

Custom Image Transformations

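Sometimes you need a transformation that torchvision doesn't provide. Compose accepts any callable, so a plain function or a class with __call__ works. Subclassing torch.nn.Module, as below, follows the convention used by many of torchvision's own transforms: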

python
class GaussianNoise(torch.nn.Module):
    """Add Gaussian noise to a tensor"""
    def __init__(self, mean=0., std=1.):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, tensor):
        # randn_like matches the input's shape, dtype and device
        return tensor + torch.randn_like(tensor) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + f'(mean={self.mean}, std={self.std})'

# Usage
custom_transform = transforms.Compose([
    transforms.ToTensor(),
    GaussianNoise(0, 0.1)
])

noisy_img_tensor = custom_transform(img)

Real-world Example: Image Classification Preprocessing

Let's put everything together in a real-world image classification scenario:

python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# Define data transforms
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Create datasets and dataloaders
image_datasets = {
    'train': ImageFolder('./data/train', data_transforms['train']),
    'val': ImageFolder('./data/val', data_transforms['val'])
}

dataloaders = {
    'train': DataLoader(image_datasets['train'], batch_size=16, shuffle=True),
    'val': DataLoader(image_datasets['val'], batch_size=16)
}

# Load an ImageNet-pretrained model (the weights= argument replaces the
# deprecated pretrained=True flag in torchvision >= 0.13)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_classes = len(image_datasets['train'].classes)
# Replace the final layer so it outputs one score per class in our dataset
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Train the model (simplified example)
def train_model(model, criterion, optimizer, dataloaders, num_epochs=5):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass; track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward + optimize only in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects / dataset_size

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    return model

# Example usage
# model = train_model(model, criterion, optimizer, dataloaders)

Advanced Image Transforms

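torchvision also provides more advanced geometric transformations for complex use cases, such as transforms.RandomPerspective and, in torchvision 0.15 and later, transforms.ElasticTransform. Writing one yourself is a good way to see how they work: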

Elastic Transformations

python
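import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF

class ElasticTransform(torch.nn.Module):
    """Elastic deformation: random per-pixel displacement fields, smoothed
    with a Gaussian blur (sigma) and scaled by alpha (in pixels).
    Accepts a PIL image or a [C, H, W] float tensor and returns a tensor."""
    def __init__(self, alpha=50.0, sigma=5.0):
        super().__init__()
        self.alpha = alpha
        self.sigma = float(sigma)

    def forward(self, img):
        img_tensor = img if isinstance(img, torch.Tensor) else TF.to_tensor(img)
        _, h, w = img_tensor.shape

        # Random displacement fields in [-1, 1] for x and y: shape [2, H, W]
        disp = torch.rand(2, h, w) * 2 - 1

        # Smooth the fields so neighbouring pixels move together, then scale
        kernel = 2 * int(3 * self.sigma) + 1  # odd kernel size covering +/-3 sigma
        disp = TF.gaussian_blur(disp, [kernel, kernel], [self.sigma, self.sigma])
        disp = disp * self.alpha

        # Build a sampling grid in grid_sample's normalized [-1, 1] coordinates
        ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing='ij')
        grid_x = xs + disp[0] * 2 / (w - 1)  # pixel offsets -> normalized offsets
        grid_y = ys + disp[1] * 2 / (h - 1)
        grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0)  # [1, H, W, 2]

        # Resample the image at the displaced positions (zeros outside the image)
        warped = F.grid_sample(img_tensor.unsqueeze(0), grid, mode='bilinear',
                               padding_mode='zeros', align_corners=True)
        return warped.squeeze(0)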

Perspective Transform

python
# Example of perspective transform
perspective_transform = transforms.RandomPerspective(distortion_scale=0.5, p=1.0)
perspective_img = perspective_transform(img)

perspective_tensor = transforms.ToTensor()(perspective_img)
# show_image(perspective_tensor, "Perspective Transformed")

Image Processing for Specific Computer Vision Tasks

Different computer vision tasks might require specific preprocessing steps:

Object Detection

python
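# Object detection preprocessing usually keeps the aspect ratio intact.
# Bounding boxes must be rescaled by the same factors as the image.
# Note: torchvision's own detection models (e.g. Faster R-CNN) expect unnormalized
# [0, 1] tensors and resize/normalize internally, so they only need ToTensor().
object_detection_transforms = transforms.Compose([
    transforms.Resize(800),  # Resize the smaller edge to 800 pixels
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])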

Semantic Segmentation

python
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

# For segmentation, spatial transforms must be applied identically to the image and its mask
def transform_segmentation_sample(image, mask):
    # Draw the random decision once and apply it to both image and mask
    if torch.rand(1).item() < 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)

    # Photometric transforms change only the image, never the mask
    image = transforms.ColorJitter(brightness=0.2, contrast=0.2)(image)
    image = TF.to_tensor(image)
    image = TF.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Keep class indices as integers; ToTensor() would rescale them to [0, 1]
    mask = torch.as_tensor(np.array(mask), dtype=torch.long)

    return image, mask

Summary

In this tutorial, we've explored image processing in PyTorch from basic operations to advanced transformations. We've covered:

  1. Loading and converting images: Using PIL and PyTorch tensors
  2. Basic transformations: Resizing, cropping, and color adjustments
  3. Creating transformation pipelines: Combining multiple transforms
  4. Data augmentation: Enhancing training data
  5. Using DataLoaders: Efficiently loading batched data
  6. Custom transformations: Creating your own specialized transforms
  7. Real-world application: Setting up preprocessing for image classification
  8. Advanced transforms: Elastic and perspective transformations
  9. Task-specific processing: Adapting transforms for various computer vision tasks

Image processing is a critical step in any computer vision pipeline. Well-designed preprocessing can significantly improve your model's performance and robustness.


Exercises

  1. Basic Image Transformation: Load an image and apply a sequence of 3 different transformations. Visualize the original and transformed images side by side.

  2. Custom Transform: Create a custom transform that adds a colored border around an image.

  3. Augmentation Experiment: Create an augmentation pipeline and apply it to the same image 5 times. Observe how the random transformations create different variations.

  4. Batch Processing: Load a small dataset of images and process them in batches using a DataLoader. Calculate and print the mean and standard deviation of the dataset.

  5. Advanced Challenge: Implement a transformation that simulates rain or snow in images for data augmentation.

Happy image processing with PyTorch!


