
PyTorch Vision Transformers

Introduction

Vision Transformers (ViT) represent a paradigm shift in computer vision. While convolutional neural networks (CNNs) have dominated the field for years, transformers—originally designed for natural language processing tasks—have proven remarkably effective for image analysis.

In this tutorial, we'll explore Vision Transformers using PyTorch, one of the most popular deep learning frameworks. By the end, you'll understand:

  • The architecture of Vision Transformers
  • How to implement a basic ViT model in PyTorch
  • Fine-tuning pre-trained models for your specific tasks
  • Practical applications of Vision Transformers

Understanding Vision Transformers

From NLP to Computer Vision

Transformers were originally introduced for text processing in the groundbreaking "Attention Is All You Need" paper. Their key innovation is the self-attention mechanism, which lets the model weigh the importance of different parts of the input sequence when computing each output.
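
To make the mechanism concrete, here is a minimal sketch of scaled dot-product self-attention in PyTorch. The tensor sizes are illustrative, and in a real transformer the queries, keys, and values come from learned linear projections rather than the raw input:

python
import torch
import torch.nn.functional as F

# Toy example: a sequence of 4 tokens, each an 8-dimensional vector
x = torch.randn(1, 4, 8)  # [batch, sequence length, embedding dim]

# For brevity we reuse x directly; a real model would project it with nn.Linear layers
q, k, v = x, x, x

# Attention scores: how much each token attends to every other token
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)  # [1, 4, 4]
weights = F.softmax(scores, dim=-1)

# Each output token is a weighted average of all value vectors
out = weights @ v  # [1, 4, 8]
print(out.shape)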

Google's "An Image is Worth 16x16 Words" paper adapted transformers to images by:

  1. Splitting images into fixed-size patches
  2. Flattening these patches into sequences of vectors
  3. Processing them with a standard transformer encoder
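
As a quick sanity check on the patch arithmetic, the snippet below assumes the standard 224x224 input with 16x16 patches:

python
import torch

image_size, patch_size, channels = 224, 16, 3

num_patches = (image_size // patch_size) ** 2    # 14 * 14 = 196 patches
patch_dim = channels * patch_size * patch_size   # 3 * 16 * 16 = 768 values per patch
print(num_patches, patch_dim)  # 196 768

# The same reshaping with tensors: unfold extracts the flattened patches
x = torch.randn(1, channels, image_size, image_size)
patches = torch.nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size)
print(patches.shape)  # [1, 768, 196]; transposed to [1, 196, 768] before the transformer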

The Architecture

A Vision Transformer consists of:

  1. Patch Embedding: Divides the image into patches and creates linear embeddings
  2. Position Embedding: Adds positional information to maintain spatial relationships
  3. Transformer Encoder: Processes the sequence through multiple layers of multi-head self-attention
  4. MLP Head: Generates the final classification or prediction

Implementing a Vision Transformer in PyTorch

Let's implement a simplified Vision Transformer for image classification:

python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        self.projection = nn.Conv2d(in_channels, embed_dim,
                                    kernel_size=patch_size,
                                    stride=patch_size)

    def forward(self, x):
        # x shape: [batch, channels, height, width]
        x = self.projection(x)   # [batch, embed_dim, grid, grid]
        x = x.flatten(2)         # [batch, embed_dim, grid*grid]
        x = x.transpose(1, 2)    # [batch, grid*grid, embed_dim]
        return x

class VisionTransformer(nn.Module):
    def __init__(self, image_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
                 num_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Class token (initialized to zero here; the paper uses truncated-normal init)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Position embeddings (one per patch, plus one for the class token)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Transformer encoder (ViT blocks use GELU activations and pre-norm)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            activation="gelu",
            norm_first=True,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=depth,
            norm=nn.LayerNorm(embed_dim)  # final LayerNorm, as in the original ViT
        )

        # MLP Head
        self.mlp_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x)

        # Add class token
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        # Add position embeddings
        x = x + self.pos_embed

        # Apply transformer
        x = self.transformer(x)

        # Use class token for classification
        x = x[:, 0]

        # MLP head
        x = self.mlp_head(x)

        return x

Let's test our implementation with a sample image:

python
# Create a simple ViT model
model = VisionTransformer(
    image_size=224,
    patch_size=16,
    in_channels=3,
    embed_dim=384,
    depth=6,
    num_heads=6,
    mlp_ratio=4.0,
    num_classes=1000
)

# Create a random input tensor (batch_size=1, channels=3, height=224, width=224)
x = torch.randn(1, 3, 224, 224)

# Forward pass
output = model(x)
print(f"Output shape: {output.shape}") # Should be [1, 1000]

Output:

Output shape: torch.Size([1, 1000])
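
As a rough size check, you can also count the parameters of the model we just built. The exact figure depends on the hyperparameters above; for this configuration it should come out to roughly 11 million:

python
num_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {num_params / 1e6:.1f}M")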

Using Pre-trained Vision Transformers

Building and training a ViT from scratch requires substantial computational resources. Fortunately, PyTorch provides pre-trained models through libraries like timm (PyTorch Image Models) and torchvision.

Using timm

python
import timm
import torch
from PIL import Image
import requests
from io import BytesIO
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Load a pre-trained ViT model
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# Get the appropriate transform
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

# Download and transform an image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert('RGB')
img_tensor = transform(img).unsqueeze(0) # Add batch dimension

# Inference
with torch.no_grad():
    output = model(img_tensor)

# Get top-5 predictions
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(probabilities, 5)

# Print predictions (class indices; see below for mapping them to names)
for i, (prob, catid) in enumerate(zip(top5_prob, top5_catid)):
    print(f"{i+1}: {100 * prob.item():.2f}% - {catid.item()}")

Example output (ImageNet class names shown in parentheses for reference):

1: 76.82% - 281 (tabby, tabby cat)
2: 10.38% - 282 (tiger cat)
3: 1.64% - 283 (Persian cat)
4: 0.52% - 478 (carton)
5: 0.33% - 287 (lynx, catamount)
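
The script above prints raw class indices. To show human-readable names like those in parentheses, you can map the indices through an ImageNet label file, for example the one used in the PyTorch Hub examples (assuming the file remains available at this URL):

python
# Download the ImageNet class-name list used by the PyTorch Hub examples
labels_url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
imagenet_classes = requests.get(labels_url).text.splitlines()

for i, (prob, catid) in enumerate(zip(top5_prob, top5_catid)):
    print(f"{i+1}: {100 * prob.item():.2f}% - {imagenet_classes[catid.item()]}")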

Using torchvision

With recent versions of torchvision, you can also access pre-trained ViT models:

python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO

# Load a pre-trained ViT model from torchvision
# (the weights argument replaces the deprecated pretrained=True)
model = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
model.eval()

# Define image transforms
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Download and transform an image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert('RGB')
img_tensor = transform(img).unsqueeze(0)

# Inference
with torch.no_grad():
    output = model(img_tensor)

# Print top prediction
_, predicted_idx = torch.max(output, 1)
print(f"Predicted class index: {predicted_idx.item()}")

Fine-tuning a Vision Transformer

Fine-tuning a pre-trained ViT for your specific task is often the most practical approach. Here's how to fine-tune a ViT for a custom dataset:

python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import timm

# Define transforms
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your dataset (assuming it's organized in folders by class)
train_dataset = ImageFolder('path/to/train_data', transform=transform)
val_dataset = ImageFolder('path/to/val_data', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Load pre-trained model
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=len(train_dataset.classes))

# Optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 5
for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_dataset)

    # Validation phase
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}")

# Save the fine-tuned model
torch.save(model.state_dict(), 'vit_finetuned.pth')
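
If your dataset is small or compute is tight, a common lighter-weight variant is to freeze the pre-trained backbone and train only the classification head. Here is a minimal sketch, assuming the timm ViT's classifier is the module named head (which it is for vit_base_patch16_224); the training loop itself is unchanged:

python
# Freeze every parameter except the classification head
for name, param in model.named_parameters():
    if not name.startswith('head'):
        param.requires_grad = False

# Only pass the trainable parameters to the optimizer
optimizer = optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-3, weight_decay=0.05
)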

Real-World Applications

1. Medical Image Analysis

Vision Transformers have shown promising results in medical image analysis. Here's a simplified example for classifying X-ray images:

python
# Load a pre-trained ViT but modify the classification head
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, 2) # Binary classification: normal vs. pneumonia

# Dataset and training loop would be similar to the previous example
# ...

2. Object Detection with ViT

While the base ViT is designed for classification, transformer-based detectors such as DETR (DEtection TRansformer) apply the same attention machinery to object detection tasks:

python
# Conceptual example: DETR pairs a CNN backbone with a transformer encoder-decoder
# (the pre-trained checkpoint below uses a ResNet-50 backbone)
from transformers import DetrForObjectDetection, DetrImageProcessor

# Load model and processor
processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')

# Process image (img is the PIL image loaded in the earlier examples)
inputs = processor(images=img, return_tensors="pt")
outputs = model(**inputs)

# Convert the raw outputs to COCO-style detections (boxes, labels, scores)
target_sizes = torch.tensor([img.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )

3. Image Segmentation

Vision Transformers have been adapted for segmentation tasks through models like SETR (SEgmentation TRansformer) and SegFormer:

python
# Conceptual example of transformer-based semantic segmentation.
# A plain ViT does not produce the multi-scale feature maps a U-Net decoder expects,
# so we use SegFormer's hierarchical Mix Transformer encoder ("mit_b0"),
# which segmentation_models_pytorch supports out of the box.
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="mit_b0",       # Transformer encoder
    encoder_weights="imagenet",
    classes=20,                  # Number of segmentation classes
    activation=None,             # Return raw logits; pair with CrossEntropyLoss
)

# The training and inference process would be similar to standard segmentation models

Summary

In this tutorial, we've covered:

  1. The architecture and key components of Vision Transformers
  2. How to implement a basic ViT from scratch in PyTorch
  3. Using pre-trained models through libraries like timm and torchvision
  4. Fine-tuning ViT models for custom tasks
  5. Real-world applications in medical imaging, object detection, and segmentation

Vision Transformers represent an exciting shift in computer vision, offering an alternative to CNNs with competitive or superior performance across many tasks. While they may require more data and compute to train from scratch, pre-trained models make them accessible for a wide range of applications.

Additional Resources

  1. Original ViT Paper: "An Image is Worth 16x16 Words"
  2. PyTorch's Official ViT Implementation
  3. Hugging Face Transformers Library
  4. timm Library Documentation

Exercises

  1. Beginner: Modify the provided ViT code to use different patch sizes and observe how it affects the model.

  2. Intermediate: Implement a hybrid CNN-Transformer model where you use a small CNN to create patch embeddings instead of linear projections.

  3. Advanced: Train a ViT from scratch on a small dataset (like CIFAR-10) and compare its performance to a ResNet model with a similar number of parameters.

  4. Expert: Implement the attention visualization to understand which parts of the image the ViT is focusing on for making predictions.

Happy experimenting with Vision Transformers!


