PyTorch Custom Datasets

In many real-world machine learning projects, you'll need to work with your own data rather than using standard datasets. PyTorch provides a flexible way to create custom datasets through its Dataset class. This tutorial will guide you through creating custom datasets in PyTorch, enabling you to efficiently load, transform, and feed your unique data into deep learning models.

Understanding the Dataset Class

PyTorch's Dataset class is an abstract class that serves as a base for all datasets. To create a custom dataset, you need to subclass Dataset and implement at least two essential methods:

  1. __len__: Returns the size of the dataset
  2. __getitem__: Returns a sample from the dataset at a given index

Let's dive into how to implement these methods to create a custom dataset.
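
At its barest, a custom dataset is just a class that stores (or points to) your samples in __init__ and exposes them through these two methods. The snippet below is a minimal sketch; MyDataset and its attribute names are placeholders:

python
from torch.utils.data import Dataset

class MyDataset(Dataset):          # placeholder name
    def __init__(self, samples):   # store (or point to) the underlying data
        self.samples = samples

    def __len__(self):             # number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, idx):    # return the sample at position idx
        return self.samples[idx]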

Creating a Simple Custom Dataset

First, let's create a simple custom dataset that generates synthetic data. This example will help you understand the basic structure of a custom dataset:

python
import torch
from torch.utils.data import Dataset

class SyntheticDataset(Dataset):
    def __init__(self, size=1000, dimensions=10):
        """
        Initialize a synthetic dataset with random data.

        Args:
            size (int): Number of samples in the dataset
            dimensions (int): Number of features for each sample
        """
        self.size = size
        # Generate random features and labels
        self.data = torch.randn(size, dimensions)
        self.labels = torch.randint(0, 2, (size,))  # Binary labels (0 or 1)

    def __len__(self):
        """Return the size of the dataset."""
        return self.size

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx (int): Index of the sample

        Returns:
            tuple: (features, label) corresponding to the sample
        """
        return self.data[idx], self.labels[idx]

# Create and use the synthetic dataset
dataset = SyntheticDataset(size=100)
print(f"Dataset size: {len(dataset)}")

# Access a single sample
features, label = dataset[0]
print(f"Sample features shape: {features.shape}")
print(f"Sample label: {label}")

Output:

Dataset size: 100
Sample features shape: torch.Size([10])
Sample label: tensor(0) # This value may vary due to randomness

Loading Data from Files

Most real-world scenarios involve loading data from files (images, text, etc.). Let's create a custom dataset to load image data from a directory:

python
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Load images from a directory.

        Args:
            root_dir (str): Directory with the images
            transform (callable, optional): Optional transform to apply to the images
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(root_dir)
                            if os.path.isfile(os.path.join(root_dir, f)) and
                            f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # In a real scenario, labels might come from the file name or a separate file
        # For this example, we derive a simple label from the file name
        label = 0 if 'cat' in self.image_files[idx].lower() else 1  # 0 for cat, 1 for dog

        return image, label

# Example usage
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Replace 'path/to/images' with your actual image directory
# image_dataset = ImageDataset(root_dir='path/to/images', transform=transform)

Loading Data from CSV Files

Let's create a dataset class for loading data from CSV files, a common format for tabular data:

python
import pandas as pd
import torch
from torch.utils.data import Dataset

class CSVDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        """
        Load data from a CSV file.

        Args:
            csv_file (str): Path to the CSV file
            transform (callable, optional): Optional transform to apply to the features
        """
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        # Assuming the last column is the label and the rest are features
        features = self.data_frame.iloc[idx, :-1].values.astype(float)
        label = self.data_frame.iloc[idx, -1]

        # Convert to torch tensors
        features = torch.tensor(features, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        if self.transform:
            features = self.transform(features)

        return features, label
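
As a quick usage sketch (assuming a hypothetical data.csv whose last column holds integer class labels), the dataset is created and indexed like any other:

python
# Replace 'data.csv' with the path to your actual CSV file
# csv_dataset = CSVDataset(csv_file='data.csv')
# features, label = csv_dataset[0]
# print(f"Features shape: {features.shape}, Label: {label}")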

Custom Dataset with Data Transformations

Often, you'll need to preprocess your data before feeding it into a model. Let's modify our custom dataset to include transformations:

python
import torch
from torch.utils.data import Dataset

class CustomDatasetWithTransform(Dataset):
    def __init__(self, data, labels, transform=None):
        """
        Initialize dataset with data, labels and optional transformations.

        Args:
            data (array-like): Features of the dataset
            labels (array-like): Labels of the dataset
            transform (callable, optional): Optional transform to apply to the data
        """
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            features = self.transform(features)

        return features, label

# Example transformation: a callable class that adds Gaussian noise to a tensor
class AddGaussianNoise:
    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        return tensor + torch.randn_like(tensor) * self.std + self.mean

# Example usage
import numpy as np

# Generate sample data
X = np.random.randn(100, 5) # 100 samples, 5 features
y = np.random.randint(0, 3, 100) # 3 classes

# Create dataset with a transform to add noise
transform = AddGaussianNoise(std=0.1)
dataset = CustomDatasetWithTransform(X, y, transform=transform)

# Check the first sample
features, label = dataset[0]
print(f"Features: {features}")
print(f"Label: {label}")

Output:

Features: tensor([-1.0754,  0.1570,  0.0463, -0.3453,  1.6221])  # Values will vary
Label: tensor(1) # This value may vary due to randomness

Using Custom Datasets with DataLoader

To efficiently batch and shuffle your data during training, you should use PyTorch's DataLoader with your custom dataset:

python
from torch.utils.data import DataLoader

# Create a synthetic dataset
dataset = SyntheticDataset(size=1000)

# Create a DataLoader
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

# Iterate over the data
for batch_idx, (features, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx + 1}:")
    print(f"  Features shape: {features.shape}")
    print(f"  Labels shape: {labels.shape}")

    # Just print the first few batches
    if batch_idx == 2:
        break

Output:

Batch 1:
Features shape: torch.Size([32, 10])
Labels shape: torch.Size([32])
Batch 2:
Features shape: torch.Size([32, 10])
Labels shape: torch.Size([32])
Batch 3:
Features shape: torch.Size([32, 10])
Labels shape: torch.Size([32])

Real-World Example: Text Classification Dataset

Let's create a more practical example: a dataset for text classification that tokenizes raw text into model-ready input IDs and attention masks:

python
import torch
from torch.utils.data import Dataset

class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        """
        Dataset for text classification tasks.

        Args:
            texts (list): List of text strings
            labels (list): List of labels corresponding to the texts
            tokenizer: Tokenizer to convert text to tokens
            max_length (int): Maximum sequence length for padding/truncation
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # Tokenize the text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # Remove the extra batch dimension added by return_tensors='pt'
        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': torch.tensor(label, dtype=torch.long)
        }

# Example usage (requires transformers library)
# from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# example_texts = [
# "I loved this movie!",
# "This film was terrible.",
# "An average movie, nothing special."
# ]
# example_labels = [1, 0, 2] # 1 for positive, 0 for negative, 2 for neutral

# text_dataset = TextClassificationDataset(example_texts, example_labels, tokenizer)
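
Because __getitem__ returns a dictionary, DataLoader's default collate function batches each field separately, so every batch is itself a dictionary of stacked tensors. A minimal sketch, assuming the commented-out text_dataset above has been created:

python
# text_dataloader = DataLoader(text_dataset, batch_size=2)
# for batch in text_dataloader:
#     print(batch['input_ids'].shape)       # e.g. torch.Size([2, 128])
#     print(batch['attention_mask'].shape)  # e.g. torch.Size([2, 128])
#     print(batch['label'].shape)           # e.g. torch.Size([2])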

Tips for Efficient Custom Datasets

  1. Lazy Loading: For large datasets, consider loading data on-demand in __getitem__ rather than loading everything in __init__ (the ImageDataset above already does this: only file names are stored up front, and images are opened when requested).

  2. Caching: Implement caching mechanisms for data that's expensive to load or process.

python
class CachedImageDataset(Dataset):
    def __init__(self, root_dir, transform=None, max_cache_size=100):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(root_dir)
                            if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.max_cache_size = max_cache_size
        self.cache = {}

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        if idx in self.cache:
            image = self.cache[idx]
        else:
            img_name = os.path.join(self.root_dir, self.image_files[idx])
            image = Image.open(img_name).convert('RGB')

            # Manage cache size by evicting the oldest entry
            if len(self.cache) >= self.max_cache_size:
                self.cache.pop(next(iter(self.cache)))

            self.cache[idx] = image

        if self.transform:
            image = self.transform(image)

        # Dummy label for demo
        label = 0

        return image, label

  3. Multiprocessing: Use the num_workers parameter in DataLoader to parallelize data loading.

python
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,   # Adjust based on your CPU
    pin_memory=True  # Faster data transfer to GPU
)

Handling Imbalanced Datasets

For imbalanced datasets, you can use a WeightedRandomSampler so that under-represented classes are drawn more often during training:

python
from torch.utils.data import WeightedRandomSampler

class ImbalancedDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# Example for creating a weighted sampler
def create_weighted_sampler(labels):
    """
    Create a weighted sampler to handle class imbalance.

    Args:
        labels (list): Dataset labels

    Returns:
        WeightedRandomSampler: Sampler that balances class distribution
    """
    class_counts = torch.bincount(torch.tensor(labels))
    class_weights = 1. / class_counts.float()

    sample_weights = [class_weights[label] for label in labels]
    sample_weights = torch.tensor(sample_weights)

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    return sampler

# Example usage
labels = [0, 0, 0, 1, 1, 2, 2, 2, 2] # Imbalanced labels
features = torch.randn(len(labels), 5) # Random features
dataset = ImbalancedDataset(features, labels)

# Create weighted sampler
sampler = create_weighted_sampler(labels)

# Use sampler with DataLoader
dataloader = DataLoader(
    dataset=dataset,
    batch_size=3,
    sampler=sampler  # Note: sampler and shuffle=True are mutually exclusive
)

# Verify the balanced sampling
sampled_labels = []
for _, labels_batch in dataloader:
    sampled_labels.extend(labels_batch.tolist())
    if len(sampled_labels) >= 15:
        break

print("Sampled labels:", sampled_labels[:15])
print("Label distribution:", torch.bincount(torch.tensor(sampled_labels[:15])))

Summary

In this tutorial, we learned how to:

  1. Create custom PyTorch datasets by subclassing the Dataset class
  2. Implement the essential __len__ and __getitem__ methods
  3. Load different types of data, including in-memory, images, and text
  4. Apply data transformations to custom datasets
  5. Use custom datasets with DataLoader for efficient batch processing
  6. Handle large datasets with caching and lazy loading
  7. Deal with imbalanced datasets using weighted sampling

Custom datasets are a powerful tool in PyTorch, allowing you to work with any type of data while maintaining the efficiency and convenience of PyTorch's data loading utilities.

Exercises

  1. Create a custom dataset that loads data from a CSV file and performs normalization on numerical features.
  2. Implement a custom dataset for time series data that creates sequences of a specified length from continuous data.
  3. Extend the ImageDataset class to support data augmentation techniques like random cropping, flipping, and color jittering.
  4. Create a custom dataset for pair-based learning (e.g., siamese networks) that returns pairs of samples with a similarity label.
  5. Implement a custom dataset with caching for audio files that computes mel-spectrograms as features.

