# PyTorch Custom Datasets
In many real-world machine learning projects, you'll need to work with your own data rather than using standard datasets. PyTorch provides a flexible way to create custom datasets through its `Dataset` class. This tutorial will guide you through creating custom datasets in PyTorch, enabling you to efficiently load, transform, and feed your unique data into deep learning models.
## Understanding the Dataset Class
PyTorch's `Dataset` class is an abstract class that serves as a base for all datasets. To create a custom dataset, you need to subclass `Dataset` and implement at least two essential methods:

- `__len__`: Returns the size of the dataset
- `__getitem__`: Returns a sample from the dataset at a given index

Let's dive into how to implement these methods to create a custom dataset.
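Here is a minimal sketch of that interface before we move on to richer examples; the class name `ListDataset` and its list-backed storage are purely illustrative:

```python
from torch.utils.data import Dataset

class ListDataset(Dataset):
    """A minimal Dataset wrapping an in-memory Python list."""
    def __init__(self, items):
        self.items = items

    def __len__(self):
        # Number of samples in the dataset
        return len(self.items)

    def __getitem__(self, idx):
        # Sample at the given index
        return self.items[idx]

dataset = ListDataset([10, 20, 30])
print(len(dataset))  # 3
print(dataset[1])    # 20
```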
## Creating a Simple Custom Dataset
First, let's create a simple custom dataset that generates synthetic data. This example will help you understand the basic structure of a custom dataset:
```python
import torch
from torch.utils.data import Dataset

class SyntheticDataset(Dataset):
    def __init__(self, size=1000, dimensions=10):
        """
        Initialize a synthetic dataset with random data.

        Args:
            size (int): Number of samples in the dataset
            dimensions (int): Number of features for each sample
        """
        self.size = size
        # Generate random features and labels
        self.data = torch.randn(size, dimensions)
        self.labels = torch.randint(0, 2, (size,))  # Binary labels (0 or 1)

    def __len__(self):
        """Return the size of the dataset."""
        return self.size

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx (int): Index of the sample

        Returns:
            tuple: (features, label) corresponding to the sample
        """
        return self.data[idx], self.labels[idx]

# Create and use the synthetic dataset
dataset = SyntheticDataset(size=100)
print(f"Dataset size: {len(dataset)}")

# Access a single sample
features, label = dataset[0]
print(f"Sample features shape: {features.shape}")
print(f"Sample label: {label}")
```
Output:

```
Dataset size: 100
Sample features shape: torch.Size([10])
Sample label: tensor(0)  # This value may vary due to randomness
```
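Once a dataset exposes `__len__` and `__getitem__`, it works with PyTorch's generic data utilities. For instance, a common next step (not part of the original example) is splitting it into training and validation subsets with `torch.utils.data.random_split`:

```python
from torch.utils.data import random_split

# Split the 100-sample dataset into 80 training and 20 validation samples
train_set, val_set = random_split(dataset, [80, 20])
print(len(train_set), len(val_set))  # 80 20
```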
## Loading Data from Files
Most real-world scenarios involve loading data from files (images, text, etc.). Let's create a custom dataset to load image data from a directory:
```python
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Load images from a directory.

        Args:
            root_dir (str): Directory with the images
            transform (callable, optional): Optional transform to apply to the images
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(root_dir)
                            if os.path.isfile(os.path.join(root_dir, f)) and
                            f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # In a real scenario, labels might be derived from the file name or a separate file
        # For this example, we just create dummy labels
        label = 0 if 'cat' in self.image_files[idx].lower() else 1  # Assume 0 for cat, 1 for dog

        return image, label

# Example usage
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Replace 'path/to/images' with your actual image directory
# image_dataset = ImageDataset(root_dir='path/to/images', transform=transform)
```
## Loading Data from CSV Files
Let's create a dataset class for loading data from CSV files, a common format for tabular data:
```python
import pandas as pd
import torch
from torch.utils.data import Dataset

class CSVDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        """
        Load data from a CSV file.

        Args:
            csv_file (str): Path to the CSV file
            transform (callable, optional): Optional transform to apply to the features
        """
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        # Assuming the last column is the label and the rest are features
        features = self.data_frame.iloc[idx, :-1].values.astype(float)
        label = self.data_frame.iloc[idx, -1]

        # Convert to torch tensors
        features = torch.tensor(features, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        if self.transform:
            features = self.transform(features)

        return features, label
```
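A minimal usage sketch follows; the file name `example.csv` and its contents are illustrative, written on the fly so the snippet is self-contained:

```python
import pandas as pd

# Write a tiny CSV whose last column is an integer class label
pd.DataFrame({
    'f1': [0.1, 0.5, 0.9],
    'f2': [1.0, 2.0, 3.0],
    'label': [0, 1, 0]
}).to_csv('example.csv', index=False)

dataset = CSVDataset('example.csv')
features, label = dataset[0]
print(features)  # tensor([0.1000, 1.0000])
print(label)     # tensor(0)
```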
## Custom Dataset with Data Transformations
Often, you'll need to preprocess your data before feeding it into a model. Let's modify our custom dataset to include transformations:
```python
import torch
from torch.utils.data import Dataset

class CustomDatasetWithTransform(Dataset):
    def __init__(self, data, labels, transform=None):
        """
        Initialize dataset with data, labels and optional transformations.

        Args:
            data (array-like): Features of the dataset
            labels (array-like): Labels of the dataset
            transform (callable, optional): Optional transform to apply to the data
        """
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            features = self.transform(features)

        return features, label

# Example transformation function
class AddGaussianNoise:
    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        return tensor + torch.randn_like(tensor) * self.std + self.mean

# Example usage
import numpy as np

# Generate sample data
X = np.random.randn(100, 5)  # 100 samples, 5 features
y = np.random.randint(0, 3, 100)  # 3 classes

# Create dataset with a transform to add noise
transform = AddGaussianNoise(std=0.1)
dataset = CustomDatasetWithTransform(X, y, transform=transform)

# Check the first sample
features, label = dataset[0]
print(f"Features: {features}")
print(f"Label: {label}")
```
Output:

```
Features: tensor([-1.0754,  0.1570,  0.0463, -0.3453,  1.6221])  # Values will vary
Label: tensor(1)  # This value may vary due to randomness
```
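To confirm the transform is actually applied, you can compare against an untransformed copy of the same data. A quick check, not part of the original example:

```python
# Same underlying tensors, but no transform applied
clean_dataset = CustomDatasetWithTransform(X, y, transform=None)

clean_features, _ = clean_dataset[0]
noisy_features, _ = dataset[0]

# Almost certainly False, since fresh Gaussian noise is added on every access
print(torch.allclose(clean_features, noisy_features))
```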
## Using Custom Datasets with DataLoader
To efficiently batch and shuffle your data during training, you should use PyTorch's `DataLoader` with your custom dataset:
```python
from torch.utils.data import DataLoader

# Create a synthetic dataset
dataset = SyntheticDataset(size=1000)

# Create a DataLoader
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

# Iterate over the data
for batch_idx, (features, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx + 1}:")
    print(f"  Features shape: {features.shape}")
    print(f"  Labels shape: {labels.shape}")

    # Just print the first few batches
    if batch_idx == 2:
        break
```
Output:

```
Batch 1:
  Features shape: torch.Size([32, 10])
  Labels shape: torch.Size([32])
Batch 2:
  Features shape: torch.Size([32, 10])
  Labels shape: torch.Size([32])
Batch 3:
  Features shape: torch.Size([32, 10])
  Labels shape: torch.Size([32])
```
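These batches plug directly into a training loop. A minimal sketch, where the linear model, loss, and optimizer settings are illustrative rather than part of the tutorial:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)  # 10 input features, 2 classes
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(2):
    for features, labels in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(features), labels)
        loss.backward()
        optimizer.step()
```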
## Real-World Example: Text Classification Dataset
Let's create a more practical example: a dataset for text classification that tokenizes raw text on the fly:
```python
import torch
from torch.utils.data import Dataset

class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        """
        Dataset for text classification tasks.

        Args:
            texts (list): List of text strings
            labels (list): List of labels corresponding to the texts
            tokenizer: Tokenizer to convert text to tokens
            max_length (int): Maximum sequence length for padding/truncation
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # Tokenize the text
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # Convert to tensors
        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': torch.tensor(label, dtype=torch.long)
        }

# Example usage (requires the transformers library)
# from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# example_texts = [
#     "I loved this movie!",
#     "This film was terrible.",
#     "An average movie, nothing special."
# ]
# example_labels = [1, 0, 2]  # 1 for positive, 0 for negative, 2 for neutral (integer class labels)
# text_dataset = TextClassificationDataset(example_texts, example_labels, tokenizer)
```
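Because `__getitem__` returns a dictionary of tensors, the default `DataLoader` collation batches each key automatically. A sketch of what that looks like, still assuming the `transformers` tokenizer above:

```python
# from torch.utils.data import DataLoader
# loader = DataLoader(text_dataset, batch_size=2, shuffle=True)
# batch = next(iter(loader))
# print(batch['input_ids'].shape)       # torch.Size([2, 128])
# print(batch['attention_mask'].shape)  # torch.Size([2, 128])
# print(batch['label'].shape)           # torch.Size([2])
```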
## Tips for Efficient Custom Datasets
- Lazy Loading: For large datasets, consider loading data on-demand in `__getitem__` rather than loading everything in `__init__` (the `ImageDataset` above already follows this pattern).
- Caching: Implement caching mechanisms for data that's expensive to load or process:
```python
class CachedImageDataset(Dataset):
    def __init__(self, root_dir, transform=None, max_cache_size=100):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(root_dir)
                            if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.max_cache_size = max_cache_size
        # Note: with num_workers > 0, each worker process keeps its own copy of this cache
        self.cache = {}

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        if idx in self.cache:
            image = self.cache[idx]
        else:
            img_name = os.path.join(self.root_dir, self.image_files[idx])
            image = Image.open(img_name).convert('RGB')

            # Manage cache size by evicting the oldest entry
            if len(self.cache) >= self.max_cache_size:
                self.cache.pop(next(iter(self.cache)))
            self.cache[idx] = image

        if self.transform:
            image = self.transform(image)

        # Dummy label for demo
        label = 0
        return image, label
```
- Multiprocessing: Use the `num_workers` parameter in `DataLoader` to parallelize data loading:
```python
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,    # Adjust based on your CPU
    pin_memory=True   # Faster data transfer to GPU
)
```
## Handling Imbalanced Datasets
For imbalanced datasets, you can create a sampler to adjust the sampling frequency:
```python
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

class ImbalancedDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# Example for creating a weighted sampler
def create_weighted_sampler(labels):
    """
    Create a weighted sampler to handle class imbalance.

    Args:
        labels (list): Dataset labels

    Returns:
        WeightedRandomSampler: Sampler that balances class distribution
    """
    class_counts = torch.bincount(torch.tensor(labels))
    class_weights = 1. / class_counts.float()
    sample_weights = [class_weights[label] for label in labels]
    sample_weights = torch.tensor(sample_weights)

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    return sampler

# Example usage
labels = [0, 0, 0, 1, 1, 2, 2, 2, 2]  # Imbalanced labels
features = torch.randn(len(labels), 5)  # Random features

dataset = ImbalancedDataset(features, labels)

# Create weighted sampler
sampler = create_weighted_sampler(labels)

# Use sampler with DataLoader
dataloader = DataLoader(
    dataset=dataset,
    batch_size=3,
    sampler=sampler  # Note: sampler is mutually exclusive with shuffle=True
)

# Verify the balanced sampling
sampled_labels = []
for _, labels_batch in dataloader:
    sampled_labels.extend(labels_batch.tolist())
    if len(sampled_labels) >= 15:
        break

print("Sampled labels:", sampled_labels[:15])
print("Label distribution:", torch.bincount(torch.tensor(sampled_labels[:15])))
```
## Summary
In this tutorial, we learned how to:

- Create custom PyTorch datasets by subclassing the `Dataset` class
- Implement the essential `__len__` and `__getitem__` methods
- Load different types of data, including in-memory arrays, CSV files, images, and text
- Apply data transformations to custom datasets
- Use custom datasets with `DataLoader` for efficient batch processing
- Handle large datasets with caching and lazy loading
- Deal with imbalanced datasets using weighted sampling
Custom datasets are a powerful tool in PyTorch, allowing you to work with any type of data while maintaining the efficiency and convenience of PyTorch's data loading utilities.
## Additional Resources
- PyTorch Documentation: Building Custom Datasets
- PyTorch Documentation: DataLoader
- PyTorch Vision Datasets - Good examples of dataset implementations
## Exercises
- Create a custom dataset that loads data from a CSV file and performs normalization on numerical features.
- Implement a custom dataset for time series data that creates sequences of a specified length from continuous data.
- Extend the `ImageDataset` class to support data augmentation techniques like random cropping, flipping, and color jittering.
- Create a custom dataset for pair-based learning (e.g., siamese networks) that returns pairs of samples with a similarity label.
- Implement a custom dataset with caching for audio files that computes mel-spectrograms as features.