
PyTorch Sampling

Introduction

Sampling is a crucial aspect of working with data in machine learning. It involves selecting a subset of data from a larger dataset or generating new data points from probability distributions. In PyTorch, sampling techniques are essential for various tasks such as:

  • Creating balanced training batches
  • Implementing data augmentation strategies
  • Generating synthetic data
  • Working with different sampling strategies for better model training

In this tutorial, we'll explore how to perform different types of sampling operations in PyTorch, starting with basic concepts and progressing to more advanced techniques.

Basic Sampling in PyTorch

Random Sampling

The most basic form of sampling is random sampling, where we select data points randomly from a dataset.

python
import torch
from torch.utils.data import Dataset, DataLoader, Subset, RandomSampler

# Create a simple dataset
class SimpleDataset(Dataset):
    def __init__(self, size=1000):
        self.data = torch.randn(size, 10)
        self.targets = torch.randint(0, 2, (size,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

# Create dataset instance
dataset = SimpleDataset()

# Random sampling with replacement
random_sampler = RandomSampler(dataset, replacement=True, num_samples=100)

# Create DataLoader with the sampler
random_loader = DataLoader(dataset, batch_size=10, sampler=random_sampler)

# Example usage
for batch_idx, (data, target) in enumerate(random_loader):
    print(f"Batch {batch_idx}: Shape = {data.shape}, Target shape = {target.shape}")
    if batch_idx == 2:  # Just print the first few batches
        break

Output:

Batch 0: Shape = torch.Size([10, 10]), Target shape = torch.Size([10])
Batch 1: Shape = torch.Size([10, 10]), Target shape = torch.Size([10])
Batch 2: Shape = torch.Size([10, 10]), Target shape = torch.Size([10])
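Note that RandomSampler draws without replacement by default, so every index appears exactly once per pass over the data. Below is a minimal sketch of that default behaviour, reusing the dataset defined above; the torch.randperm line shows the equivalent shuffled index permutation.

python
# Without replacement (the default): every index appears exactly once per epoch
no_replacement_sampler = RandomSampler(dataset)
no_replacement_loader = DataLoader(dataset, batch_size=10, sampler=no_replacement_sampler)

# Conceptually equivalent to drawing indices from a shuffled permutation
shuffled_indices = torch.randperm(len(dataset))
print(shuffled_indices[:10])  # first 10 shuffled indices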

Subset Sampling

Sometimes you might want to create a fixed subset of your dataset:

python
import numpy as np

# Create indices for a subset
indices = np.random.choice(len(dataset), size=500, replace=False)
subset = Subset(dataset, indices)

# Create DataLoader for the subset
subset_loader = DataLoader(subset, batch_size=10, shuffle=True)

print(f"Original dataset size: {len(dataset)}")
print(f"Subset size: {len(subset)}")

Output:

Original dataset size: 1000
Subset size: 500
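If you simply want to split a dataset into disjoint parts (for example, train and validation sets), torch.utils.data.random_split is a convenient alternative to building the index list yourself. A minimal sketch, assuming an 80/20 split of the same 1000-sample dataset:

python
from torch.utils.data import random_split

# Split the dataset into two disjoint subsets (80% train, 20% validation)
train_subset, val_subset = random_split(dataset, [800, 200])

train_loader = DataLoader(train_subset, batch_size=10, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=10, shuffle=False)

print(f"Train subset size: {len(train_subset)}, Validation subset size: {len(val_subset)}")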

Weighted and Stratified Sampling

Weighted Sampling

Weighted sampling draws elements with probabilities proportional to per-sample weights. This is especially useful for imbalanced datasets: by giving minority-class samples larger weights, each batch sees the classes in roughly equal proportion.

python
from torch.utils.data import WeightedRandomSampler

# Get all targets from dataset
all_targets = torch.tensor([dataset[i][1] for i in range(len(dataset))])

# Count samples per class
class_count = torch.bincount(all_targets)
print(f"Samples per class: {class_count}")

# Calculate class weights (inverse frequency)
class_weights = 1.0 / class_count.float()
print(f"Class weights: {class_weights}")

# Assign weight to each sample
sample_weights = class_weights[all_targets]

# Create weighted sampler
weighted_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

# Create DataLoader with the weighted sampler
weighted_loader = DataLoader(
    dataset,
    batch_size=10,
    sampler=weighted_sampler
)

# Verify the balance
sampled_targets = []
for _, targets in weighted_loader:
    sampled_targets.append(targets)
    if len(sampled_targets) > 10:  # Just sample a few batches
        break

sampled_targets = torch.cat(sampled_targets)
balanced_count = torch.bincount(sampled_targets)
print(f"Class distribution after weighted sampling: {balanced_count}")

Stratified Sampling for Cross-Validation

For cross-validation, it's important to maintain the class distribution across folds. PyTorch doesn't ship a stratified splitter, but scikit-learn's StratifiedKFold integrates easily:

python
from sklearn.model_selection import StratifiedKFold
import numpy as np

# Get all targets as numpy array
targets_np = all_targets.numpy()

# Create stratified k-fold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Print the fold sizes
for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets_np)), targets_np)):
    train_targets = targets_np[train_idx]
    val_targets = targets_np[val_idx]

    train_class_dist = np.bincount(train_targets)
    val_class_dist = np.bincount(val_targets)

    print(f"Fold {fold+1}:")
    print(f"  Training samples: {len(train_idx)}, Class distribution: {train_class_dist}")
    print(f"  Validation samples: {len(val_idx)}, Class distribution: {val_class_dist}")

Batch Sampling Techniques

Sequential and Random Batch Sampling

PyTorch's DataLoader provides two main ways to create batches:

python
# Sequential batches (shuffle=False)
sequential_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False
)

# Random batches (shuffle=True)
random_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True
)
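Under the hood, shuffle=False and shuffle=True simply select a SequentialSampler or a RandomSampler, so the same two loaders can be written with the samplers spelled out explicitly:

python
from torch.utils.data import SequentialSampler, RandomSampler

# Equivalent loaders with explicit samplers instead of the shuffle flag
sequential_loader = DataLoader(dataset, batch_size=32, sampler=SequentialSampler(dataset))
random_loader = DataLoader(dataset, batch_size=32, sampler=RandomSampler(dataset))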

Custom Batch Sampler

For more complex batch sampling patterns, you can create a custom batch sampler:

python
from torch.utils.data import BatchSampler

# Create a custom batch sampler
class BalancedBatchSampler(BatchSampler):
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.targets = torch.tensor([dataset[i][1] for i in range(len(dataset))])
        self.indices_per_class = {
            int(class_idx): torch.where(self.targets == class_idx)[0]
            for class_idx in torch.unique(self.targets)
        }
        self.num_classes = len(self.indices_per_class)
        self.samples_per_class = batch_size // self.num_classes

    def __iter__(self):
        # Work on a copy so the sampler can be iterated for more than one epoch
        remaining = dict(self.indices_per_class)

        # Repeat until we run out of samples
        while sum(len(indices) for indices in remaining.values()) > 0:
            batch = []

            # Take up to samples_per_class samples from each class
            for class_idx, available in remaining.items():
                if len(available) > 0:
                    take = min(self.samples_per_class, len(available))
                    batch.extend(available[:take].tolist())

                    # Remove the used indices from this class's pool
                    remaining[class_idx] = available[take:]

            if batch:
                yield batch
            else:
                break

    def __len__(self):
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

# Create a DataLoader with our custom batch sampler
balanced_batch_sampler = BalancedBatchSampler(dataset, batch_size=10)
balanced_loader = DataLoader(dataset, batch_sampler=balanced_batch_sampler)

# Check the class distribution in the first few batches
for batch_idx, (_, targets) in enumerate(balanced_loader):
    print(f"Batch {batch_idx}: Class distribution = {torch.bincount(targets)}")
    if batch_idx == 2:  # Just print the first few batches
        break

Sampling from Probability Distributions

PyTorch also provides functions to sample from standard probability distributions.

Uniform and Normal Distributions

python
# Sample from uniform distribution
uniform_samples = torch.rand(5, 3) # 5 samples, each with 3 dimensions
print("Uniform samples (0-1):")
print(uniform_samples)

# Sample from normal distribution
normal_samples = torch.randn(5, 3) # 5 samples, each with 3 dimensions
print("\nNormal samples (mean=0, std=1):")
print(normal_samples)

# Custom range uniform distribution
custom_uniform = torch.rand(5, 3) * 10 - 5 # Uniform from -5 to 5
print("\nCustom uniform samples (-5 to 5):")
print(custom_uniform)

# Custom normal distribution
custom_normal = torch.randn(5, 3) * 2 + 1 # Mean = 1, Std = 2
print("\nCustom normal samples (mean=1, std=2):")
print(custom_normal)
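Beyond these tensor-creation helpers, the torch.distributions module provides distribution objects with sample() and log_prob() methods. A minimal sketch of the same uniform and normal examples written with explicit distribution objects:

python
import torch.distributions as dist

# Uniform distribution over [-5, 5)
uniform_dist = dist.Uniform(low=-5.0, high=5.0)
print(uniform_dist.sample((5, 3)))

# Normal distribution with mean 1 and standard deviation 2
normal_dist = dist.Normal(loc=1.0, scale=2.0)
print(normal_dist.sample((5, 3)))

# Log-density of a point under the normal distribution
print(normal_dist.log_prob(torch.tensor(1.0)))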

Categorical and Multinomial Sampling

For discrete distributions:

python
# Define probabilities for 3 categories
probs = torch.tensor([0.2, 0.3, 0.5])

# Sample a single categorical value
categorical_sample = torch.multinomial(probs, num_samples=1)
print(f"Categorical sample: {categorical_sample.item()}")

# Sample multiple categorical values with replacement
multi_samples = torch.multinomial(probs, num_samples=10, replacement=True)
print(f"Multiple samples: {multi_samples}")

# Count occurrences of each category
sample_counts = torch.bincount(multi_samples, minlength=len(probs))
print(f"Sample counts: {sample_counts}")
print(f"Empirical probabilities: {sample_counts / sample_counts.sum()}")
print(f"True probabilities: {probs}")

Real-world Application: Sampling for Imbalanced Classification

Let's implement a complete example of using weighted sampling for an imbalanced classification problem:

python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler, TensorDataset
import matplotlib.pyplot as plt

# Generate an imbalanced dataset (binary classification)
def create_imbalanced_dataset(n_samples=1000, imbalance_ratio=0.1):
    # Features: 2D data points
    X = torch.randn(n_samples, 2)

    # Create imbalanced labels (class 1 is minority)
    n_minority = int(n_samples * imbalance_ratio)
    y = torch.zeros(n_samples)
    minority_indices = torch.randperm(n_samples)[:n_minority]
    y[minority_indices] = 1

    # Make the minority class separable by shifting it
    X[y == 1, :] += torch.tensor([3.0, 3.0])

    return X, y

# Create dataset
X, y = create_imbalanced_dataset(imbalance_ratio=0.05)
dataset = TensorDataset(X, y)

# Visualize the dataset
plt.figure(figsize=(8, 6))
plt.scatter(X[y == 0, 0].numpy(), X[y == 0, 1].numpy(), alpha=0.5, label='Class 0 (Majority)')
plt.scatter(X[y == 1, 0].numpy(), X[y == 1, 1].numpy(), alpha=0.5, label='Class 1 (Minority)')
plt.legend()
plt.title('Imbalanced Dataset')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.savefig('imbalanced_dataset.png')

# Compare regular sampling vs weighted sampling
print(f"Class distribution in dataset: {torch.bincount(y.long())}")

# Regular DataLoader
regular_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Weighted sampling
# Calculate weights for imbalanced dataset
class_counts = torch.bincount(y.long())
class_weights = 1.0 / class_counts.float()
sample_weights = class_weights[y.long()]

weighted_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(y),
    replacement=True
)

weighted_loader = DataLoader(
    dataset,
    batch_size=64,
    sampler=weighted_sampler
)

# Compare distributions
def check_distribution(loader, name):
    sampled_targets = []
    for _, targets in loader:
        sampled_targets.append(targets)
    sampled_targets = torch.cat(sampled_targets)
    class_dist = torch.bincount(sampled_targets.long())
    print(f"{name} sampling distribution: {class_dist}")

check_distribution(regular_loader, "Regular")
check_distribution(weighted_loader, "Weighted")

# Define a simple model for the classification task
class SimpleClassifier(nn.Module):
    def __init__(self):
        super(SimpleClassifier, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x).squeeze()

# Function to train the model
def train_model(model, dataloader, epochs=10):
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.BCELoss()

    losses = []

    for epoch in range(epochs):
        epoch_loss = 0
        for X_batch, y_batch in dataloader:
            # Forward pass
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        losses.append(epoch_loss / len(dataloader))
        print(f"Epoch {epoch+1}/{epochs}, Loss: {losses[-1]:.4f}")

    return losses

# Train with regular sampling
regular_model = SimpleClassifier()
regular_losses = train_model(regular_model, regular_loader)

# Train with weighted sampling
weighted_model = SimpleClassifier()
weighted_losses = train_model(weighted_model, weighted_loader)

# Function to evaluate model
def evaluate_model(model, X, y):
    model.eval()
    with torch.no_grad():
        y_pred = (model(X) > 0.5).float()
        accuracy = (y_pred == y).float().mean()

        # Calculate per-class accuracy (fall back to a zero tensor if a class is absent)
        class0_idx = (y == 0)
        class1_idx = (y == 1)

        acc_class0 = (y_pred[class0_idx] == y[class0_idx]).float().mean() if class0_idx.any() else torch.tensor(0.0)
        acc_class1 = (y_pred[class1_idx] == y[class1_idx]).float().mean() if class1_idx.any() else torch.tensor(0.0)

    return accuracy.item(), acc_class0.item(), acc_class1.item()

# Evaluate both models
reg_acc, reg_acc0, reg_acc1 = evaluate_model(regular_model, X, y)
weighted_acc, weighted_acc0, weighted_acc1 = evaluate_model(weighted_model, X, y)

print("\nModel Performance:")
print(f"Regular Sampling: Overall Acc = {reg_acc:.4f}, Class 0 = {reg_acc0:.4f}, Class 1 = {reg_acc1:.4f}")
print(f"Weighted Sampling: Overall Acc = {weighted_acc:.4f}, Class 0 = {weighted_acc0:.4f}, Class 1 = {weighted_acc1:.4f}")

This example demonstrates how weighted sampling can help improve model performance on imbalanced datasets by ensuring the model sees enough examples from the minority class during training.

Summary

In this tutorial, we covered different sampling techniques in PyTorch:

  • Random and subset sampling with PyTorch's built-in samplers
  • Weighted sampling to handle imbalanced datasets
  • Stratified sampling for creating representative splits
  • Custom batch sampling for specific sampling patterns
  • Sampling from probability distributions
  • A practical application of sampling for imbalanced classification

Effective sampling strategies can significantly improve model performance, especially in scenarios with limited or imbalanced data. By using the appropriate sampling technique, you can make better use of your available data and train more robust models.

Exercises

  1. Modify the BalancedBatchSampler to handle datasets with more than two classes
  2. Create a custom sampler that implements curriculum learning (starting with easier samples and gradually introducing harder ones)
  3. Implement a version of weighted sampling that changes the sampling weights during training based on model performance
  4. Create a visualization that shows how different sampling strategies affect the class distribution in batches
  5. Experiment with different weighted sampling strategies for the imbalanced classification problem and compare their performance
