PyTorch Image Segmentation

Introduction

Image segmentation is a crucial computer vision task that involves dividing an image into multiple segments or regions, each corresponding to a different object or part of an object. Unlike image classification, which assigns a single label to an entire image, segmentation assigns a label to each pixel in the image. This pixel-level classification enables applications like medical image analysis, autonomous driving, and scene understanding.

In this tutorial, we'll explore how to implement image segmentation using PyTorch, one of the most popular deep learning frameworks. We'll start with the fundamentals, then build and train a segmentation model on a real dataset.

Understanding Image Segmentation

Types of Image Segmentation

  1. Semantic Segmentation: Assigns a class label to each pixel without differentiating between instances of the same class.
  2. Instance Segmentation: Detects and segments each instance of an object separately, even if they belong to the same class.
  3. Panoptic Segmentation: Combines semantic and instance segmentation by identifying both "stuff" (background) and "things" (objects).

In this tutorial, we'll focus on semantic segmentation, which is the most common starting point for beginners.

Setting Up the Environment

Before we begin, let's ensure we have all the necessary packages:

python
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image

Dataset Preparation

For this tutorial, we'll use a subset of the PASCAL VOC dataset, which is commonly used for segmentation tasks. Let's create a custom dataset class to load our images and masks:

python
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)  # grayscale mask of per-pixel class indices

        if self.transform:
            image = self.transform(image)

        # Resize the mask with nearest-neighbor interpolation so the class
        # indices survive; bilinear resizing and ToTensor's rescaling to
        # [0, 1] would both corrupt the labels.
        mask = mask.resize((256, 256), resample=Image.NEAREST)
        mask = torch.from_numpy(np.array(mask)).long()

        return image, mask

Now let's define our transformations and create data loaders:

python
# Define transformations (applied to the images only; the masks are
# handled inside the dataset class)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Create datasets
train_dataset = SegmentationDataset(
    image_dir='path/to/train/images',
    mask_dir='path/to/train/masks',
    transform=transform
)

val_dataset = SegmentationDataset(
    image_dir='path/to/val/images',
    mask_dir='path/to/val/masks',
    transform=transform
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
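
With the loaders in place, it's worth sanity-checking one batch to confirm the shapes; the expected output below assumes the 256×256 resize and directory layout above:

python
# Inspect one batch from the training loader
images, masks = next(iter(train_loader))
print(images.shape)    # torch.Size([4, 3, 256, 256]) -- RGB images
print(masks.shape)     # torch.Size([4, 256, 256])    -- one class index per pixel
print(masks.unique())  # class indices present in this batch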

Building a Simple U-Net Model

One of the most popular architectures for image segmentation is U-Net. It consists of an encoder path (contracting) to capture context and a decoder path (expanding) that enables precise localization:

python
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()

        # Encoder (downsampling)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(256, 512))

        # Bridge
        self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(512, 1024))

        # Decoder (upsampling)
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = DoubleConv(1024, 512)

        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = DoubleConv(512, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = DoubleConv(256, 128)

        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = DoubleConv(128, 64)

        # Output layer
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder path
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder path with skip connections
        x = self.up1(x5)
        x = torch.cat([x, x4], dim=1)  # Skip connection
        x = self.up_conv1(x)

        x = self.up2(x)
        x = torch.cat([x, x3], dim=1)  # Skip connection
        x = self.up_conv2(x)

        x = self.up3(x)
        x = torch.cat([x, x2], dim=1)  # Skip connection
        x = self.up_conv3(x)

        x = self.up4(x)
        x = torch.cat([x, x1], dim=1)  # Skip connection
        x = self.up_conv4(x)

        x = self.outc(x)
        return x
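
Because the encoder halves the spatial resolution four times, the input height and width should be divisible by 16 (our 256×256 images are) so the skip connections line up. A quick shape check on a dummy batch:

python
# Verify the model produces one logit map per class at full resolution
net = UNet(n_channels=3, n_classes=21)
dummy = torch.randn(1, 3, 256, 256)
print(net(dummy).shape)  # torch.Size([1, 21, 256, 256])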

Training the Model

Now, let's define the training loop for our segmentation model:

python
# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_channels=3, n_classes=21).to(device)  # 21 classes for PASCAL VOC

# Define loss function and optimizer; PASCAL VOC marks object borders
# with the value 255, which we exclude from the loss
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)  # (B, H, W) long tensor of class indices

        # Forward pass: outputs are (B, n_classes, H, W) logits
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print statistics
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

    # Validate the model
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            val_loss += loss.item()

    print(f'Validation Loss: {val_loss/len(val_loader):.4f}')

print('Training completed!')
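
Once training finishes, you'll usually want to persist the weights; the filename here is just an example:

python
# Save the trained weights
torch.save(model.state_dict(), 'unet_voc.pth')

# Reload them later into a freshly constructed model:
# model = UNet(n_channels=3, n_classes=21)
# model.load_state_dict(torch.load('unet_voc.pth', map_location=device))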

Visualizing Predictions

Let's create a function to visualize our segmentation results:

python
def visualize_prediction(model, image, mask, device):
    model.eval()
    with torch.no_grad():
        # Prepare image: HWC numpy array -> (1, C, H, W) tensor
        image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)

        # Get prediction: per-pixel argmax over the class logits
        output = model(image_tensor)
        pred_mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()

    # Set up visualization
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    axes[0].imshow(image)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    # Ground truth mask
    axes[1].imshow(mask.squeeze().cpu().numpy())
    axes[1].set_title('Ground Truth')
    axes[1].axis('off')

    # Predicted mask
    axes[2].imshow(pred_mask)
    axes[2].set_title('Prediction')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

# Example usage
sample_image, sample_mask = next(iter(val_loader))
sample_image = sample_image[0].permute(1, 2, 0).cpu().numpy()
visualize_prediction(model, sample_image, sample_mask[0], device)

Using Pretrained Models

For more advanced applications, we can leverage pretrained models from torchvision:

python
import torchvision.models.segmentation as seg_models

num_classes = 21  # set this to your dataset's number of classes

# Load a pretrained FCN model (on torchvision >= 0.13, prefer
# weights=seg_models.FCN_ResNet50_Weights.DEFAULT over pretrained=True)
fcn_model = seg_models.fcn_resnet50(pretrained=True)

# For fine-tuning on a custom dataset, replace the final classifier layer
fcn_model.classifier[4] = nn.Conv2d(512, num_classes, kernel_size=1)

# The same pattern works for DeepLabV3:
# deeplabv3_model = seg_models.deeplabv3_resnet50(pretrained=True)
# deeplabv3_model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1)
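
One thing to watch: unlike our UNet, torchvision's segmentation models return a dictionary, with the segmentation logits stored under the 'out' key:

python
# torchvision segmentation models wrap their logits in a dict
fcn_model.eval()
with torch.no_grad():
    batch = torch.randn(1, 3, 256, 256)  # dummy input for illustration
    logits = fcn_model(batch)['out']     # shape (1, num_classes, 256, 256)
    pred = logits.argmax(dim=1)          # per-pixel class indices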

Real-world Application: Medical Image Segmentation

Image segmentation is widely used in medical imaging to identify and isolate regions of interest, such as tumors or organs. Let's look at a simplified example for brain MRI segmentation:

python
# Example for brain MRI segmentation
class BrainMRIDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        # Similar to our previous dataset class...
        pass

    def __getitem__(self, idx):
        # Load an MRI scan and its segmentation mask
        pass

# Create a U-Net with a single input channel (grayscale MRI)
mri_model = UNet(n_channels=1, n_classes=4)  # background, gray matter, white matter, CSF

# Training and evaluation would follow the same steps as before
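
In medical imaging, overlap metrics such as the Dice coefficient are usually preferred over plain pixel accuracy, because the structures of interest occupy only a small fraction of the scan. Here is a minimal sketch of a per-class Dice score (the smoothing term `eps` is our own addition to avoid division by zero):

python
def dice_score(pred, target, num_classes, eps=1e-6):
    """Mean Dice coefficient over classes.

    pred, target: (H, W) tensors of per-pixel class indices.
    """
    scores = []
    for c in range(num_classes):
        pred_c = (pred == c).float()
        target_c = (target == c).float()
        intersection = (pred_c * target_c).sum()
        union = pred_c.sum() + target_c.sum()
        scores.append((2 * intersection + eps) / (union + eps))
    return torch.stack(scores).mean()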

Summary

In this tutorial, we've covered the fundamentals of image segmentation using PyTorch:

  1. We explored what image segmentation is and its various types
  2. We learned how to prepare datasets for segmentation tasks
  3. We implemented the U-Net architecture, a popular choice for segmentation
  4. We defined a training loop with appropriate loss functions
  5. We created visualization tools to evaluate our model's performance
  6. We saw how to use pretrained models for advanced applications
  7. We discussed a real-world example in medical imaging

Image segmentation is a powerful computer vision technique with applications across numerous domains including healthcare, autonomous driving, augmented reality, and satellite imaging. As you continue your journey in computer vision, experiment with different architectures and datasets to improve your segmentation models.

Additional Resources

  1. PyTorch Segmentation Models - A library with pre-implemented segmentation models
  2. PASCAL VOC Dataset - A standard benchmark for segmentation tasks
  3. COCO Dataset - A large-scale segmentation dataset
  4. U-Net Paper - The original U-Net research paper

Exercises

  1. Implement a simple image segmentation model for a binary task (e.g., foreground vs. background segmentation).
  2. Modify the U-Net architecture to include residual connections and evaluate if it improves performance.
  3. Fine-tune a pretrained DeepLabV3 model on a custom dataset of your choice.
  4. Implement data augmentation techniques specifically designed for segmentation tasks.
  5. Compare the performance of different loss functions (Dice Loss, Focal Loss, etc.) for imbalanced classes in segmentation.
