PyTorch Dataset Class
Data is the foundation of machine learning, and handling it efficiently is crucial for training deep learning models. PyTorch provides the Dataset class as a fundamental building block for organizing and processing data. In this tutorial, we'll explore how to use and customize datasets in PyTorch to streamline your data handling workflow.
Introduction to PyTorch Dataset
The Dataset class in PyTorch is an abstract class that represents a dataset. It's part of the torch.utils.data module and provides a consistent way to access data samples. By using the Dataset class, you can:
- Load data from various sources (files, databases, web)
- Apply transformations to data
- Efficiently batch and shuffle your data
- Create custom datasets tailored to your specific needs
Let's dive into how to use this powerful tool in your PyTorch projects.
Basic Structure of a PyTorch Dataset
A custom PyTorch Dataset typically implements three essential methods:
- __init__: Constructor that initializes the dataset
- __len__: Returns the number of samples in the dataset
- __getitem__: Returns a sample from the dataset at the given index
Here's a simple example of a Dataset class structure:
import torch
from torch.utils.data import Dataset

class SimpleDataset(Dataset):
    def __init__(self, data_tensor, target_tensor):
        self.data = data_tensor
        self.targets = target_tensor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]
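To see these pieces in action, here's a quick usage sketch with small made-up tensors (the shapes are arbitrary illustrative choices):
# Made-up data: 8 samples with 4 features each, and binary targets
data = torch.randn(8, 4)
targets = torch.randint(0, 2, (8,))

dataset = SimpleDataset(data, targets)
print(len(dataset))  # calls __len__, prints 8
x, y = dataset[3]    # indexing calls __getitem__
print(x.shape, y)    # torch.Size([4]) and a scalar label tensor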
Creating a Synthetic Dataset
Let's create a simple synthetic dataset to understand how the Dataset class works:
import torch
from torch.utils.data import Dataset, DataLoader

# Creating a synthetic dataset
class SyntheticDataset(Dataset):
    def __init__(self, size=100, feature_dim=5):
        # Generate random features and labels
        self.features = torch.randn(size, feature_dim)
        self.labels = torch.randint(0, 2, (size,)).float()

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# Create a dataset instance
dataset = SyntheticDataset(size=100, feature_dim=5)

# Access a sample
sample_features, sample_label = dataset[0]
print(f"Features shape: {sample_features.shape}")
print(f"Features: {sample_features}")
print(f"Label: {sample_label}")

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# Iterate through batches
for batch_idx, (features, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx} - Features shape: {features.shape}, Labels shape: {labels.shape}")
    if batch_idx >= 2:  # Show only first 3 batches
        break
Sample output:
Features shape: torch.Size([5])
Features: tensor([-1.0786, 0.7983, -1.4072, -0.3799, -0.3789])
Label: tensor(0.)
Batch 0 - Features shape: torch.Size([10, 5]), Labels shape: torch.Size([10])
Batch 1 - Features shape: torch.Size([10, 5]), Labels shape: torch.Size([10])
Batch 2 - Features shape: torch.Size([10, 5]), Labels shape: torch.Size([10])
Loading Data from Files
In real-world scenarios, you'll likely load data from files. Let's create a Dataset that loads data from a CSV file:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset

class CSVDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (str): Path to the csv file
            transform (callable, optional): Optional transform to be applied on a sample
        """
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        # Get features (all columns except the last one)
        features = self.data_frame.iloc[idx, :-1].values.astype(np.float32)

        # Get label (last column)
        label = self.data_frame.iloc[idx, -1]

        # Convert to tensors
        features = torch.tensor(features)
        label = torch.tensor(label, dtype=torch.long)

        # Apply transforms if any
        if self.transform:
            features = self.transform(features)

        return features, label
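A usage sketch under stated assumptions: data.csv is a hypothetical file whose last column holds an integer class label and whose remaining columns are numeric features. The dataset then plugs straight into a DataLoader:
from torch.utils.data import DataLoader

# "data.csv" is a placeholder path for illustration
csv_dataset = CSVDataset("data.csv")
features, label = csv_dataset[0]
print(features.shape, label)

csv_loader = DataLoader(csv_dataset, batch_size=32, shuffle=True)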
Working with Image Data
PyTorch is commonly used for image processing tasks. Here's an example of a Dataset for loading images:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir (str): Path to the image directory
            transform (callable, optional): Optional transform to be applied on images
        """
        self.image_dir = image_dir
        # Fall back to a default transform if none is provided; doing this
        # here (rather than inside __getitem__) avoids mutating the dataset
        # while it's being read
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])
        self.transform = transform
        self.image_files = [f for f in os.listdir(image_dir)
                            if os.path.isfile(os.path.join(image_dir, f))
                            and f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')

        # Apply transforms
        image = self.transform(image)

        # For this example, we're using the filename as a simple label.
        # In real applications, you would use actual labels.
        label_str = self.image_files[idx].split('_')[0]  # Extract label from filename
        try:
            label = int(label_str)
        except ValueError:
            label = 0  # Default label if parsing fails

        return image, label
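A usage sketch under the same assumptions as the class above: ./images is a hypothetical directory of files named like 0_example.jpg, where the leading number encodes the label:
from torch.utils.data import DataLoader

# "./images" is a placeholder directory for illustration
image_dataset = ImageDataset("./images")
image, label = image_dataset[0]
print(image.shape)  # torch.Size([3, 224, 224]) with the default transform

image_loader = DataLoader(image_dataset, batch_size=16, shuffle=True)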
Custom Transformations
Datasets often need transformations to preprocess the data. Let's see how to include custom transformations:
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class TransformableDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        if self.transform:
            x = self.transform(x)
        return x, y

# Example usage
data = torch.randn(100, 3, 32, 32)      # 100 RGB images of size 32x32
targets = torch.randint(0, 10, (100,))  # 100 random integer labels between 0 and 9

# Define transforms (RandomHorizontalFlip accepts tensors in torchvision >= 0.8)
transform = transforms.Compose([
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.RandomHorizontalFlip(p=0.5)
])

# Create dataset with transforms
dataset = TransformableDataset(data, targets, transform=transform)

# Sample data
sample_x, sample_y = dataset[0]
print(f"Sample shape: {sample_x.shape}")
print(f"Sample label: {sample_y}")
Combining Multiple Datasets
Sometimes you need to combine multiple datasets. PyTorch provides the ConcatDataset class for this purpose:
from torch.utils.data import ConcatDataset, Subset
# Create two synthetic datasets
dataset1 = SyntheticDataset(size=50, feature_dim=5)
dataset2 = SyntheticDataset(size=70, feature_dim=5)
# Combine datasets
combined_dataset = ConcatDataset([dataset1, dataset2])
print(f"Dataset1 size: {len(dataset1)}")
print(f"Dataset2 size: {len(dataset2)}")
print(f"Combined dataset size: {len(combined_dataset)}")
# You can also create subsets of datasets
subset = Subset(combined_dataset, indices=[0, 10, 20, 30])
print(f"Subset size: {len(subset)}")
# Access a sample from the subset
sample_features, sample_label = subset[0]
print(f"Subset sample features: {sample_features}")
Sample output:
Dataset1 size: 50
Dataset2 size: 70
Combined dataset size: 120
Subset size: 4
Subset sample features: tensor([-0.2931, 0.5789, 0.4700, -0.0652, 0.5660])
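Closely related to Subset is torch.utils.data.random_split, a common way to carve one dataset into training and validation partitions. A small sketch, continuing from the combined dataset above (the 96/24 split is an arbitrary choice):
from torch.utils.data import random_split

# Randomly partition the 120-sample combined dataset
train_set, val_set = random_split(combined_dataset, [96, 24])
print(f"Train size: {len(train_set)}, Val size: {len(val_set)}")  # 96 and 24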
Real-World Example: Custom Text Dataset
Let's create a more complex, real-world example: a dataset for text classification:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import re
from collections import Counter
class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, vocab=None, max_len=100):
        """
        Args:
            texts (list): List of text strings
            labels (list): List of corresponding labels
            vocab (dict, optional): Vocabulary mapping words to indices
            max_len (int): Maximum sequence length
        """
        self.texts = texts
        self.labels = labels
        self.max_len = max_len

        # Create vocabulary if not provided
        if vocab is None:
            self.vocab = self._create_vocabulary(texts)
        else:
            self.vocab = vocab

        # Add special tokens
        if '<PAD>' not in self.vocab:
            self.vocab['<PAD>'] = len(self.vocab)
        if '<UNK>' not in self.vocab:
            self.vocab['<UNK>'] = len(self.vocab)

        self.pad_idx = self.vocab['<PAD>']
        self.unk_idx = self.vocab['<UNK>']

    def _create_vocabulary(self, texts, max_vocab_size=10000):
        # Simple tokenization and vocab creation
        all_words = []
        for text in texts:
            words = re.findall(r'\w+', text.lower())
            all_words.extend(words)
        word_counts = Counter(all_words)
        common_words = word_counts.most_common(max_vocab_size)
        vocab = {}
        for i, (word, _) in enumerate(common_words):
            vocab[word] = i
        return vocab

    def _tokenize(self, text):
        # Simple tokenizer
        tokens = re.findall(r'\w+', text.lower())
        return tokens

    def _convert_to_indices(self, tokens):
        # Convert tokens to indices
        indices = []
        for token in tokens:
            if token in self.vocab:
                indices.append(self.vocab[token])
            else:
                indices.append(self.unk_idx)
        return indices

    def _pad_sequence(self, indices):
        # Pad or truncate sequence to max_len
        if len(indices) >= self.max_len:
            return indices[:self.max_len]
        else:
            padding = [self.pad_idx] * (self.max_len - len(indices))
            return indices + padding

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # Process the text
        tokens = self._tokenize(text)
        indices = self._convert_to_indices(tokens)
        padded_indices = self._pad_sequence(indices)

        # Convert to tensors
        text_tensor = torch.tensor(padded_indices, dtype=torch.long)
        label_tensor = torch.tensor(label, dtype=torch.long)

        return text_tensor, label_tensor
# Example usage
texts = [
    "This movie is amazing and I loved it!",
    "This was the worst film I've ever seen.",
    "The acting was great but the plot was confusing.",
    "I would recommend this to everyone!"
]
labels = [1, 0, 0, 1]  # 1 for positive, 0 for negative

# Create dataset
text_dataset = TextClassificationDataset(texts, labels, max_len=20)

# Check vocabulary size
print(f"Vocabulary size: {len(text_dataset.vocab)}")

# Get a sample
sample_text, sample_label = text_dataset[0]
print(f"Sample token indices: {sample_text}")
print(f"Original text: {texts[0]}")
print(f"Label: {sample_label}")

# Create DataLoader
text_dataloader = DataLoader(text_dataset, batch_size=2, shuffle=True)

# Get a batch
for batch_texts, batch_labels in text_dataloader:
    print(f"Batch shape: {batch_texts.shape}")
    print(f"Labels: {batch_labels}")
    break
Sample output:
Vocabulary size: 28
Sample token indices: tensor([19,  1, 16,  3, 12, 13, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
Original text: This movie is amazing and I loved it!
Label: tensor(1)
Batch shape: torch.Size([2, 20])
Labels: tensor([0, 0])
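One design note: padding every sample to max_len inside __getitem__ is simple but wastes computation on batches of short texts. An alternative sketch, assuming a variant of the dataset whose __getitem__ returns unpadded index tensors, is to pad per batch with a custom collate_fn:
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch is a list of (text_tensor, label_tensor) pairs of varying lengths
    texts, labels = zip(*batch)
    # Pad only to the longest sequence in this batch; the padding value
    # should match the dataset's pad index
    padded = pad_sequence(texts, batch_first=True, padding_value=0)
    return padded, torch.stack(labels)

# Hypothetical usage with a dataset that skips _pad_sequence:
# loader = DataLoader(unpadded_dataset, batch_size=2, collate_fn=pad_collate)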
Advanced Topic: Map-Style and Iterable-Style Datasets
PyTorch supports two types of datasets: map-style and iterable-style. We've been focusing on map-style datasets, but let's briefly explore iterable-style datasets:
from torch.utils.data import DataLoader, IterableDataset

class StreamingDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))

# Create iterable dataset
stream_dataset = StreamingDataset(1, 1000)

# Create dataloader
stream_loader = DataLoader(stream_dataset, batch_size=5)

# Iterate over dataloader
for i, batch in enumerate(stream_loader):
    print(f"Batch {i}: {batch}")
    if i >= 2:
        break
Sample output:
Batch 0: tensor([1, 2, 3, 4, 5])
Batch 1: tensor([6, 7, 8, 9, 10])
Batch 2: tensor([11, 12, 13, 14, 15])
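A caveat worth knowing: with num_workers > 0, each DataLoader worker gets its own copy of an IterableDataset, so the stream above would be yielded once per worker. The standard remedy, per the PyTorch documentation, is to shard the stream inside __iter__ using get_worker_info(). A sketch:
import math
from torch.utils.data import IterableDataset, get_worker_info

class ShardedStreamingDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            # Single-process loading: yield the full range
            return iter(range(self.start, self.end))
        # Multi-process loading: give each worker a disjoint slice
        per_worker = math.ceil((self.end - self.start) / worker_info.num_workers)
        start = self.start + worker_info.id * per_worker
        end = min(start + per_worker, self.end)
        return iter(range(start, end))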
Summary
The PyTorch Dataset class is a powerful abstraction that allows you to manage and preprocess your data efficiently. In this tutorial, we've covered:
- Creating basic PyTorch Datasets
- Working with different data types (numerical data, images, text)
- Applying transformations to your data
- Combining and subsetting datasets
- Creating a real-world text classification dataset
- Introduction to iterable-style datasets
With these tools, you can build custom datasets for any machine learning task, making your data handling code more organized and efficient.
Exercises
- Create a custom Dataset to load and preprocess the MNIST dataset without using torchvision.datasets.MNIST
- Implement a Dataset that loads CSV data and applies different transformations to numerical features
- Build a Dataset for a time series prediction task that creates sequences of fixed length
- Create a Dataset that combines text and image data (e.g., for image captioning)
- Implement a custom sampling strategy using WeightedRandomSampler to handle imbalanced datasets
By mastering PyTorch's Dataset class, you'll be able to handle data efficiently for any machine learning task, from basic classification to complex multi-modal learning problems.