
PyTorch Dataset Class

Data is the foundation of machine learning, and handling it efficiently is crucial for training deep learning models. PyTorch provides the Dataset class as a fundamental building block for organizing and processing data. In this tutorial, we'll explore how to use and customize datasets in PyTorch to streamline your data handling workflow.

Introduction to PyTorch Dataset

The Dataset class in PyTorch is an abstract class that represents a dataset. It's part of the torch.utils.data module and provides a consistent way to access data samples. By using the Dataset class, you can:

  • Load data from various sources (files, databases, web)
  • Apply transformations to data
  • Efficiently batch and shuffle your data (in combination with DataLoader)
  • Create custom datasets tailored to your specific needs

Let's dive into how to use this powerful tool in your PyTorch projects.

Basic Structure of a PyTorch Dataset

A map-style PyTorch Dataset implements three essential methods:

  1. __init__: Constructor that initializes the dataset
  2. __len__: Returns the number of samples in the dataset
  3. __getitem__: Returns a sample from the dataset at the given index

Here's a simple example of a Dataset class structure:

python
import torch
from torch.utils.data import Dataset

class SimpleDataset(Dataset):
    def __init__(self, data_tensor, target_tensor):
        self.data = data_tensor
        self.targets = target_tensor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]
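
Once the class is defined, you can wrap any pair of same-length tensors in it and index it like a list. Here's a quick, illustrative sanity check (the shapes below are arbitrary):

python
features = torch.randn(8, 4)         # 8 samples with 4 features each
targets = torch.randint(0, 2, (8,))  # 8 binary labels

dataset = SimpleDataset(features, targets)
print(len(dataset))  # 8 (calls __len__)
x, y = dataset[3]    # calls __getitem__(3)
print(x.shape, y)    # torch.Size([4]) and a scalar label tensor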

Creating a Synthetic Dataset

Let's create a simple synthetic dataset to understand how the Dataset class works:

python
import torch
from torch.utils.data import Dataset, DataLoader

# Creating a synthetic dataset
class SyntheticDataset(Dataset):
    def __init__(self, size=100, feature_dim=5):
        # Generate random features and labels
        self.features = torch.randn(size, feature_dim)
        self.labels = torch.randint(0, 2, (size,)).float()

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# Create a dataset instance
dataset = SyntheticDataset(size=100, feature_dim=5)

# Access a sample
sample_features, sample_label = dataset[0]
print(f"Features shape: {sample_features.shape}")
print(f"Features: {sample_features}")
print(f"Label: {sample_label}")

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# Iterate through batches
for batch_idx, (features, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx} - Features shape: {features.shape}, Labels shape: {labels.shape}")
    if batch_idx >= 2:  # Show only first 3 batches
        break

Sample output:

Features shape: torch.Size([5])
Features: tensor([-1.0786, 0.7983, -1.4072, -0.3799, -0.3789])
Label: tensor(0.)
Batch 0 - Features shape: torch.Size([10, 5]), Labels shape: torch.Size([10])
Batch 1 - Features shape: torch.Size([10, 5]), Labels shape: torch.Size([10])
Batch 2 - Features shape: torch.Size([10, 5]), Labels shape: torch.Size([10])

Loading Data from Files

In real-world scenarios, you'll likely load data from files. Let's create a Dataset that loads data from a CSV file:

python
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset

class CSVDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (str): Path to the csv file
            transform (callable, optional): Optional transform to be applied on a sample
        """
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        # Get features (all columns except the last one)
        features = self.data_frame.iloc[idx, :-1].values.astype(np.float32)
        # Get label (last column)
        label = self.data_frame.iloc[idx, -1]

        # Convert to tensors
        features = torch.tensor(features)
        label = torch.tensor(label, dtype=torch.long)

        # Apply transforms if any
        if self.transform:
            features = self.transform(features)

        return features, label
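
A minimal usage sketch, assuming a hypothetical data.csv whose last column holds integer class labels and whose remaining columns are numeric features:

python
# "data.csv" is a hypothetical file; adjust the path and column layout to your data
csv_dataset = CSVDataset(csv_file="data.csv")

features, label = csv_dataset[0]
print(features.shape, features.dtype)  # torch.Size([num_features]) torch.float32
print(label)                           # tensor(<class index>)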

Working with Image Data

PyTorch is commonly used for image processing tasks. Here's an example of a Dataset for loading images:

python
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir (str): Path to the image directory
            transform (callable, optional): Optional transform to be applied on images
        """
        self.image_dir = image_dir
        # Fall back to a default transform if none is provided
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])
        self.transform = transform
        self.image_files = [f for f in os.listdir(image_dir)
                            if os.path.isfile(os.path.join(image_dir, f))
                            and f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')

        # Apply transforms
        image = self.transform(image)

        # For this example, we're using the filename as a simple label
        # In real applications, you would use actual labels
        label_str = self.image_files[idx].split('_')[0]  # Extract label from filename
        try:
            label = int(label_str)
        except ValueError:
            label = 0  # Default label if parsing fails

        return image, label
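
And a usage sketch, assuming a hypothetical images/ directory containing files named like 3_cat.jpg, where the prefix before the underscore encodes the label:

python
# "images/" is a hypothetical directory of files named like "<label>_<name>.jpg"
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

image_dataset = ImageDataset(image_dir="images/", transform=transform)
image, label = image_dataset[0]
print(image.shape)  # torch.Size([3, 128, 128])
print(label)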

Custom Transformations

Datasets often need transformations to preprocess the data. Let's see how to include custom transformations:

python
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class TransformableDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]

        if self.transform:
            x = self.transform(x)

        return x, y

# Example usage
data = torch.randn(100, 3, 32, 32)  # 100 RGB images of size 32x32
targets = torch.randint(0, 10, (100,))  # 100 random integer labels between 0 and 9

# Define transforms
transform = transforms.Compose([
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.RandomHorizontalFlip(p=0.5)
])

# Create dataset with transforms
dataset = TransformableDataset(data, targets, transform=transform)

# Sample data
sample_x, sample_y = dataset[0]
print(f"Sample shape: {sample_x.shape}")
print(f"Sample label: {sample_y}")

Combining Multiple Datasets

Sometimes you need to combine multiple datasets. PyTorch provides the ConcatDataset for this purpose:

python
from torch.utils.data import ConcatDataset, Subset

# Create two synthetic datasets
dataset1 = SyntheticDataset(size=50, feature_dim=5)
dataset2 = SyntheticDataset(size=70, feature_dim=5)

# Combine datasets
combined_dataset = ConcatDataset([dataset1, dataset2])

print(f"Dataset1 size: {len(dataset1)}")
print(f"Dataset2 size: {len(dataset2)}")
print(f"Combined dataset size: {len(combined_dataset)}")

# You can also create subsets of datasets
subset = Subset(combined_dataset, indices=[0, 10, 20, 30])
print(f"Subset size: {len(subset)}")

# Access a sample from the subset
sample_features, sample_label = subset[0]
print(f"Subset sample features: {sample_features}")

Sample output:

Dataset1 size: 50
Dataset2 size: 70
Combined dataset size: 120
Subset size: 4
Subset sample features: tensor([-0.2931, 0.5789, 0.4700, -0.0652, 0.5660])
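
For the common case of carving one dataset into train/validation partitions, torch.utils.data also provides random_split. A minimal sketch applied to the combined dataset above (the 100/20 split and the seed are arbitrary choices):

python
from torch.utils.data import random_split

# Split the 120-sample combined dataset into train/validation partitions;
# a seeded generator makes the split reproducible
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(combined_dataset, [100, 20], generator=generator)

print(f"Train size: {len(train_set)}, Val size: {len(val_set)}")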

Real-World Example: Custom Text Dataset

Let's create a more complex, real-world example: a dataset for text classification:

python
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import re
from collections import Counter

class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, vocab=None, max_len=100):
        """
        Args:
            texts (list): List of text strings
            labels (list): List of corresponding labels
            vocab (dict, optional): Vocabulary mapping words to indices
            max_len (int): Maximum sequence length
        """
        self.texts = texts
        self.labels = labels
        self.max_len = max_len

        # Create vocabulary if not provided
        if vocab is None:
            self.vocab = self._create_vocabulary(texts)
        else:
            self.vocab = vocab

        # Add special tokens
        if '<PAD>' not in self.vocab:
            self.vocab['<PAD>'] = len(self.vocab)
        if '<UNK>' not in self.vocab:
            self.vocab['<UNK>'] = len(self.vocab)

        self.pad_idx = self.vocab['<PAD>']
        self.unk_idx = self.vocab['<UNK>']

    def _create_vocabulary(self, texts, max_vocab_size=10000):
        # Simple tokenization and vocab creation
        all_words = []
        for text in texts:
            words = re.findall(r'\w+', text.lower())
            all_words.extend(words)

        word_counts = Counter(all_words)
        common_words = word_counts.most_common(max_vocab_size)

        vocab = {}
        for i, (word, _) in enumerate(common_words):
            vocab[word] = i

        return vocab

    def _tokenize(self, text):
        # Simple tokenizer
        tokens = re.findall(r'\w+', text.lower())
        return tokens

    def _convert_to_indices(self, tokens):
        # Convert tokens to indices
        indices = []
        for token in tokens:
            if token in self.vocab:
                indices.append(self.vocab[token])
            else:
                indices.append(self.unk_idx)
        return indices

    def _pad_sequence(self, indices):
        # Pad or truncate sequence to max_len
        if len(indices) >= self.max_len:
            return indices[:self.max_len]
        else:
            padding = [self.pad_idx] * (self.max_len - len(indices))
            return indices + padding

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # Process the text
        tokens = self._tokenize(text)
        indices = self._convert_to_indices(tokens)
        padded_indices = self._pad_sequence(indices)

        # Convert to tensor
        text_tensor = torch.tensor(padded_indices, dtype=torch.long)
        label_tensor = torch.tensor(label, dtype=torch.long)

        return text_tensor, label_tensor

# Example usage
texts = [
    "This movie is amazing and I loved it!",
    "This was the worst film I've ever seen.",
    "The acting was great but the plot was confusing.",
    "I would recommend this to everyone!"
]
labels = [1, 0, 0, 1]  # 1 for positive, 0 for negative

# Create dataset
text_dataset = TextClassificationDataset(texts, labels, max_len=20)

# Check vocabulary size
print(f"Vocabulary size: {len(text_dataset.vocab)}")

# Get a sample
sample_text, sample_label = text_dataset[0]
print(f"Sample token indices: {sample_text}")
print(f"Original text: {texts[0]}")
print(f"Label: {sample_label}")

# Create DataLoader
text_dataloader = DataLoader(text_dataset, batch_size=2, shuffle=True)

# Get a batch
for batch_texts, batch_labels in text_dataloader:
    print(f"Batch shape: {batch_texts.shape}")
    print(f"Labels: {batch_labels}")
    break

Sample output:

Vocabulary size: 26
Sample token indices: tensor([ 0,  4,  5,  6,  7,  1,  8,  9, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
        24, 24])
Original text: This movie is amazing and I loved it!
Label: tensor(1)
Batch shape: torch.Size([2, 20])
Labels: tensor([0, 0])

Advanced Topic: Map-Style and Iterable-Style Datasets

PyTorch supports two types of datasets: map-style (indexed via __getitem__, which we've used so far) and iterable-style (which yield samples from __iter__). Let's briefly explore iterable-style datasets:

python
from torch.utils.data import DataLoader, IterableDataset

class StreamingDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))

# Create iterable dataset
stream_dataset = StreamingDataset(1, 1000)

# Create dataloader
stream_loader = DataLoader(stream_dataset, batch_size=5)

# Iterate over dataloader
for i, batch in enumerate(stream_loader):
    print(f"Batch {i}: {batch}")
    if i >= 2:
        break

Sample output:

Batch 0: tensor([1, 2, 3, 4, 5])
Batch 1: tensor([6, 7, 8, 9, 10])
Batch 2: tensor([11, 12, 13, 14, 15])
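
One caveat worth knowing: with num_workers > 0, each DataLoader worker process gets its own copy of an IterableDataset, so a naive __iter__ yields every element once per worker. A common fix is to shard the work inside __iter__ using torch.utils.data.get_worker_info(); here's a minimal sketch of that pattern:

python
import math
from torch.utils.data import get_worker_info

class ShardedStreamingDataset(IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            # Single-process loading: yield the whole range
            return iter(range(self.start, self.end))
        # Multi-process loading: give each worker a disjoint slice of the range
        per_worker = math.ceil((self.end - self.start) / worker_info.num_workers)
        worker_start = self.start + worker_info.id * per_worker
        worker_end = min(worker_start + per_worker, self.end)
        return iter(range(worker_start, worker_end))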

Summary

The PyTorch Dataset class is a powerful abstraction that allows you to manage and preprocess your data efficiently. In this tutorial, we've covered:

  1. Creating basic PyTorch Datasets
  2. Working with different data types (numerical data, images, text)
  3. Applying transformations to your data
  4. Combining and subsetting datasets
  5. Creating a real-world text classification dataset
  6. Introduction to iterable-style datasets

With these tools, you can build custom datasets for any machine learning task, making your data handling code more organized and efficient.

Exercises

  1. Create a custom Dataset to load and preprocess the MNIST dataset without using torchvision.datasets.MNIST
  2. Implement a Dataset that loads CSV data and applies different transformations to numerical features
  3. Build a Dataset for a time series prediction task that creates sequences of fixed length
  4. Create a Dataset that combines text and image data (e.g., for image captioning)
  5. Implement a custom sampling strategy using WeightedRandomSampler to handle imbalanced datasets

By mastering PyTorch's Dataset class, you'll be able to handle data efficiently for any machine learning task, from basic classification to complex multi-modal learning problems.


