PyTorch Data Augmentation

Data augmentation is a powerful technique that artificially expands your training dataset by creating modified versions of existing data. In deep learning, especially for computer vision tasks, data augmentation helps your models generalize better by exposing them to various transformations of the same underlying data.

Why Use Data Augmentation?

  • Prevents overfitting: By introducing variations, your model learns the essential features instead of memorizing the training data
  • Improves generalization: Models become more robust to different perspectives, lighting conditions, and positions
  • Mitigates data scarcity: Creates more training examples when collecting additional data is expensive or impractical

Basic Data Augmentation in PyTorch

PyTorch provides data augmentation functionality through the torchvision.transforms module. Let's start with some basic image transformations:

python
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Load a sample image
img = Image.open('sample_dog.jpg')

# Define a basic transformation
basic_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),               # 50% chance of flipping
    transforms.RandomRotation(10),                        # Rotate by up to 10 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2)  # Adjust brightness and contrast
])

# Apply transformation
augmented_img = basic_transform(img)

# Display original and augmented images
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(img)
ax1.set_title('Original Image')
ax1.axis('off')

ax2.imshow(augmented_img)
ax2.set_title('Augmented Image')
ax2.axis('off')

plt.show()
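
Note that each call to basic_transform draws fresh random parameters, so repeated runs give different results. If you need reproducible augmentations (for debugging, say), you can seed PyTorch's random number generator first. A small sketch, assuming a recent torchvision where the random transforms draw from torch's global RNG:

python
torch.manual_seed(42)               # fix the RNG state
augmented_a = basic_transform(img)

torch.manual_seed(42)               # same seed, same random parameters
augmented_b = basic_transform(img)  # identical to augmented_a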

Common Transformation Types

PyTorch offers a wide variety of transformations that you can combine to create sophisticated augmentation pipelines:

Geometric Transformations

python
geometric_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),         # Flip horizontally
    transforms.RandomVerticalFlip(),           # Flip vertically
    transforms.RandomRotation(degrees=15),     # Random rotation
    transforms.RandomAffine(
        degrees=10,                            # Rotation
        translate=(0.1, 0.1),                  # Translation (fraction of total width/height)
        scale=(0.9, 1.1),                      # Scale
        shear=5                                # Shear angle in degrees
    ),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),  # Perspective change
    transforms.Resize((224, 224))              # Resize to standard dimensions
])

Color and Intensity Transformations

python
color_transforms = transforms.Compose([
    transforms.ColorJitter(
        brightness=0.2,                        # Brightness adjustment
        contrast=0.2,                          # Contrast adjustment
        saturation=0.2,                        # Saturation adjustment
        hue=0.1                                # Hue adjustment
    ),
    transforms.RandomGrayscale(p=0.1),         # 10% chance to convert to grayscale
    transforms.RandomAutocontrast(p=0.2),      # Auto-adjust contrast
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.3),  # Sharpen image
])

Advanced Augmentations

python
advanced_transforms = transforms.Compose([
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),  # Apply Gaussian blur
    transforms.RandomInvert(p=0.1),                      # Invert colors
    transforms.RandomPosterize(bits=2, p=0.2),           # Reduce number of bits per color channel
    transforms.RandomSolarize(threshold=128, p=0.2),     # Invert all pixel values above threshold
    transforms.ToTensor(),                               # RandomErasing operates on tensors, not PIL images
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1)),  # Randomly erase rectangular areas
])

Implementing Data Augmentation in Training Pipeline

Now let's integrate data augmentation into a complete training pipeline using PyTorch's Dataset and DataLoader classes:

python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Training transformations with data augmentation
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Validation transformations (no augmentation needed)
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load datasets with the corresponding transformations
train_dataset = datasets.ImageFolder('data/train', transform=train_transforms)
val_dataset = datasets.ImageFolder('data/val', transform=val_transforms)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

# Initialize model, loss function, and optimizer
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 'pretrained=True' is deprecated in newer torchvision
num_classes = len(train_dataset.classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training loop (simplified)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)

# Example of running one epoch
epoch_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
print(f"Training Loss: {epoch_loss:.4f}")

Real-World Example: Custom Data Augmentation

Sometimes you might need custom augmentations not provided by PyTorch. Here's how to create a custom transformation that simulates weather conditions:

python
import numpy as np

class AddRain(object):
    """Add a simple diagonal rain effect to an image."""
    def __init__(self, rain_density=0.1, drop_length=20, drop_width=1, drop_color=(200, 200, 200)):
        self.rain_density = rain_density
        self.drop_length = drop_length
        self.drop_width = drop_width
        self.drop_color = drop_color

    def __call__(self, img):
        img_np = np.array(img)
        h, w = img_np.shape[:2]

        # Number of drops scales with image area and rain density
        num_drops = int((h * w) * self.rain_density / 100)

        for _ in range(num_drops):
            x = np.random.randint(0, w)
            y = np.random.randint(0, h)

            # Draw a diagonal raindrop as a short line of pixels
            for j in range(self.drop_length):
                if y + j < h and x + j < w:
                    img_np[y + j, x + j:min(x + j + self.drop_width, w), :3] = self.drop_color

        return Image.fromarray(img_np)

# Usage
custom_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    AddRain(rain_density=0.2),
    transforms.ToTensor()
])

# Apply to an image
rainy_img = custom_transforms(img)

Using On-the-Fly Data Augmentation

A key advantage of PyTorch's transformation system is that augmentations are applied dynamically during training. Each time a data sample is loaded, it gets a new random transformation:

python
import numpy as np

# Create a simple dataset
class AugmentationDemo(torch.utils.data.Dataset):
    def __init__(self, img_path, transform=None, samples=5):
        self.img = Image.open(img_path)
        self.transform = transform
        self.samples = samples

    def __len__(self):
        return self.samples

    def __getitem__(self, idx):
        # The same source image gets a fresh random transform on every access
        if self.transform:
            return self.transform(self.img)
        return self.img

# Define transformations
demo_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.5),
    transforms.ToTensor()
])

# Create dataset and dataloader
demo_dataset = AugmentationDemo('sample_image.jpg', transform=demo_transform)
demo_loader = DataLoader(demo_dataset, batch_size=5)

# Display multiple augmentations of the same image
batch = next(iter(demo_loader))

fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i, img_tensor in enumerate(batch):
    img = img_tensor.permute(1, 2, 0).numpy()  # Convert from CxHxW to HxWxC
    axes[i].imshow(np.clip(img, 0, 1))
    axes[i].set_title(f"Augmentation {i+1}")
    axes[i].axis('off')

plt.tight_layout()
plt.show()

Albumentations: A Powerful Alternative

While PyTorch's built-in transformations are excellent, the Albumentations library offers even more augmentation options with a focus on performance:

python
# pip install albumentations
import albumentations as A
import cv2

# Define transformations using Albumentations
albu_transform = A.Compose([
    A.RandomCrop(width=256, height=256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Rotate(limit=20, p=0.7),
    A.GaussNoise(p=0.3),
    # Cutout was deprecated in newer Albumentations releases; CoarseDropout is the replacement
    A.CoarseDropout(max_holes=8, max_height=8, max_width=8, fill_value=0, p=0.3),
])

# Load image with OpenCV (Albumentations works with numpy arrays)
img = cv2.imread('sample_image.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB

# Apply transformation
augmented = albu_transform(image=img)
augmented_img = augmented['image']

# To use with PyTorch, you'll need to convert to tensor
to_tensor = transforms.ToTensor()
tensor_img = to_tensor(augmented_img)
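
Alternatively, Albumentations ships its own tensor conversion, ToTensorV2, which can be appended to the pipeline so it emits a CxHxW torch tensor directly. A minimal sketch (note that ToTensorV2 only rearranges the layout; it does not rescale values to [0, 1]):

python
from albumentations.pytorch import ToTensorV2

albu_tensor_transform = A.Compose([
    A.RandomCrop(width=256, height=256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    ToTensorV2(),  # HxWxC numpy array -> CxHxW torch tensor (no rescaling)
])

tensor_img = albu_tensor_transform(image=img)['image']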

Data Augmentation Best Practices

  1. Choose transformations that make sense for your data: If your camera is always upright, vertical flips might not be useful.

  2. Don't overdo it: Extreme augmentations can make the task too difficult for the model.

  3. Use different augmentations for different tasks: Object detection might need different augmentations than classification.

  4. Benchmark with and without augmentation: Always compare performance to verify that augmentation helps.

  5. Validate on non-augmented data: Always evaluate your model on non-transformed validation data.

  6. Consider progressive augmentation: Start with mild augmentations and increase their strength during training, as in the sketch below.
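
For the progressive strategy in item 6, one simple approach is to rebuild the training transform each epoch with a strength that grows over training. Here is a minimal sketch building on the pipeline above; the linear schedule and parameter ranges are illustrative choices, not a standard recipe:

python
def make_transform(strength):
    # strength in [0, 1] scales the magnitude of each augmentation (illustrative ranges)
    return transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(1.0 - 0.4 * strength, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30 * strength),
        transforms.ColorJitter(brightness=0.4 * strength, contrast=0.4 * strength),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

num_epochs = 10
for epoch in range(num_epochs):
    # Ramp augmentation strength linearly from 0 to 1 over training
    train_dataset.transform = make_transform(epoch / (num_epochs - 1))
    epoch_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch + 1}: loss {epoch_loss:.4f}")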

Summary

Data augmentation is a crucial technique for improving model performance, especially in computer vision tasks. PyTorch provides a rich set of tools for implementing various augmentation strategies through the torchvision.transforms module. By artificially expanding your dataset through these transformations, you can train more robust models that generalize better to unseen data.

In this tutorial, we covered:

  • Basic image transformations in PyTorch
  • Common augmentation techniques and their implementation
  • Integrating data augmentation in the training pipeline
  • Creating custom transformations
  • Best practices for effective data augmentation

Exercises

  1. Implement a training loop that uses different augmentation strategies and compare their performance.
  2. Create a custom transformation that simulates motion blur.
  3. Experiment with different augmentation parameters and observe their effect on a simple classification task.
  4. Implement an augmentation strategy that gradually increases transformation intensity during training.
  5. Use Albumentations library to create a more complex augmentation pipeline and compare its performance with PyTorch's built-in transforms.


If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)