TensorFlow Datasets Library
Introduction
The TensorFlow Datasets (TFDS) library gives machine learning practitioners a collection of ready-to-use datasets for building, training, and evaluating models. It eliminates the need to manually download, extract, and preprocess data, which is time-consuming and error-prone. TFDS handles these steps automatically, letting you focus on model development rather than data preparation.
In this guide, we'll explore how to use TensorFlow Datasets to:
- Load popular datasets with just a few lines of code
- Explore and understand dataset structure
- Preprocess and transform datasets for training
- Create custom datasets when needed
Getting Started with TensorFlow Datasets
Installation
Before we begin, let's make sure TensorFlow Datasets is installed:
pip install tensorflow tensorflow-datasets
Basic Usage
Let's start with a simple example of loading a dataset:
import tensorflow as tf
import tensorflow_datasets as tfds
# Load the MNIST dataset
mnist_dataset = tfds.load(name="mnist", split="train")
# Convert the dataset to a format suitable for training
mnist_dataset = mnist_dataset.map(lambda example: (example['image'], example['label']))
mnist_dataset = mnist_dataset.batch(32)
# Iterate through the dataset
for images, labels in mnist_dataset.take(1):
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")
Output:
Batch shape: (32, 28, 28, 1)
Labels shape: (32,)
In this example, we loaded the MNIST dataset, which contains 28x28 grayscale images of handwritten digits, and prepared it for training by batching the data.
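You can also inspect a dataset's structure without reading any data: every tf.data.Dataset exposes an element_spec property describing the shapes and dtypes of its elements.
# Inspect the structure of the batched dataset (no data is read)
print(mnist_dataset.element_spec)
# Expect a (TensorSpec(shape=(None, 28, 28, 1), dtype=tf.uint8, ...),
#           TensorSpec(shape=(None,), dtype=tf.int64, ...)) pair;
# the exact formatting varies across TensorFlow versions.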
Exploring Available Datasets
TFDS provides access to hundreds of datasets across various domains like computer vision, natural language processing, audio processing, and more.
Listing Available Datasets
You can see all available datasets with:
import tensorflow_datasets as tfds
# List all available datasets
all_datasets = tfds.list_builders()
print(f"Total number of datasets: {len(all_datasets)}")
print(f"First 10 datasets: {all_datasets[:10]}")
Output:
Total number of datasets: 287
First 10 datasets: ['abstract_reasoning', 'aeslc', 'aflw2k3d', 'air_quality', 'amazon_us_reviews', 'animated_knots', 'answer_equivalence', 'aqua', 'arc', 'bair_robot_pushing_small']
The exact count and names depend on your TFDS version, since the catalog grows with each release.
Getting Dataset Information
To get more information about a specific dataset:
import tensorflow_datasets as tfds
# Get info about the CIFAR-10 dataset
cifar_info = tfds.builder("cifar10").info
print(f"Name: {cifar_info.name}")
print(f"Description: {cifar_info.description[:150]}...")
print(f"Number of classes: {cifar_info.features['label'].num_classes}")
print(f"Class names: {cifar_info.features['label'].names}")
print(f"Training examples: {cifar_info.splits['train'].num_examples}")
print(f"Test examples: {cifar_info.splits['test'].num_examples}")
Output:
Name: cifar10
Description: The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test...
Number of classes: 10
Class names: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Training examples: 50000
Test examples: 10000
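TFDS also includes a convenience helper, tfds.as_dataframe, for previewing a handful of examples as a pandas DataFrame (pandas must be installed); a quick sketch:
import tensorflow_datasets as tfds

# Preview the first four CIFAR-10 examples as a pandas DataFrame
ds, info = tfds.load("cifar10", split="train", with_info=True)
print(tfds.as_dataframe(ds.take(4), info))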
Loading and Preparing Datasets
Loading Different Splits
Datasets typically come with predefined splits like "train", "test", and sometimes "validation". You can load specific splits:
import tensorflow_datasets as tfds
# Load specific splits
train_ds = tfds.load("mnist", split="train[:80%]")
validation_ds = tfds.load("mnist", split="train[80%:90%]")
test_ds = tfds.load("mnist", split="train[90%:]")
print(f"Number of training batches: {len(list(train_ds))}")
print(f"Number of validation batches: {len(list(validation_ds))}")
print(f"Number of test batches: {len(list(test_ds))}")
Output:
Number of training examples: 48000
Number of validation examples: 6000
Number of test examples: 6000
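The split syntax is more flexible than percentage slices alone: you can slice by absolute counts, concatenate splits with +, and generate even sub-splits with tfds.even_splits. A few examples, all standard TFDS split-API features:
import tensorflow_datasets as tfds

# Slice by absolute example count
first_1k = tfds.load("mnist", split="train[:1000]")
# Concatenate splits: all 70,000 MNIST examples
full_ds = tfds.load("mnist", split="train+test")
# Generate three non-overlapping sub-splits of roughly equal size
part0, part1, part2 = tfds.even_splits("train", n=3)
shard_ds = tfds.load("mnist", split=part0)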
Data Preprocessing
Most datasets need preprocessing before training. Here's how to normalize image data and one-hot encode labels:
import tensorflow as tf
import tensorflow_datasets as tfds
def preprocess_image(example):
    # Normalize pixel values to [0, 1]
    image = tf.cast(example['image'], tf.float32) / 255.0
    # One-hot encode the label
    label = tf.one_hot(example['label'], depth=10)
    return image, label
# Load and preprocess MNIST
mnist_ds = tfds.load("mnist", split="train[:1000]")
mnist_ds = mnist_ds.map(preprocess_image)
mnist_ds = mnist_ds.batch(32).prefetch(tf.data.AUTOTUNE)
# Check the processed data
for images, labels in mnist_ds.take(1):
    print(f"Image batch shape: {images.shape}")
    print(f"Image value range: {tf.reduce_min(images).numpy()} to {tf.reduce_max(images).numpy()}")
    print(f"Label batch shape: {labels.shape}")
    print(f"First label (one-hot encoded): {labels[0]}")
Output:
Image batch shape: (32, 28, 28, 1)
Image value range: 0.0 to 1.0
Label batch shape: (32, 10)
First label (one-hot encoded): [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
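By default, map applies the preprocessing function one example at a time; for larger datasets you can parallelize it with the standard num_parallel_calls option:
# Same pipeline, with preprocessing parallelized across CPU cores
mnist_ds = (
    tfds.load("mnist", split="train")
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)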
Working with Different Types of Datasets
Image Datasets
Let's load and display images from the CIFAR-10 dataset:
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
# Load CIFAR-10
cifar10_ds = tfds.load("cifar10", split="train[:10]")
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
# Create a figure to display images
plt.figure(figsize=(10, 5))
for i, example in enumerate(cifar10_ds.take(10)):
    image = example["image"].numpy()
    label = example["label"].numpy()
    plt.subplot(2, 5, i + 1)
    plt.imshow(image)
    plt.title(class_names[label])
    plt.axis('off')
plt.tight_layout()
plt.show()
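For quick visual checks, TFDS also offers tfds.show_examples, which builds a figure like the one above for you. It expects an unbatched dataset loaded without as_supervised, plus the dataset's info object:
import tensorflow_datasets as tfds

ds, info = tfds.load("cifar10", split="train", with_info=True)
fig = tfds.show_examples(ds, info)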
Text Datasets
TFDS also provides text datasets. Let's explore the IMDb movie reviews dataset:
import tensorflow_datasets as tfds
# Load IMDb dataset
imdb_ds = tfds.load("imdb_reviews", split="train[:5]")
# Print a few examples
for i, example in enumerate(imdb_ds.take(5)):
    text = example["text"].numpy().decode("utf-8")
    label = "Positive" if example["label"].numpy() == 1 else "Negative"
    print(f"\nReview {i+1} ({label}):")
    print(f"{text[:200]}...")
Output:
Review 1 (Positive):
This was an absolutely incredible movie. Don Cheadle delivers an Oscar worthy performance as the hotel manager of the Hotel Des Milles Collines. Nick Nolte and Joaquin Phoenix are also excellent in their roles as a UN colonel and...
Review 2 (Negative):
I really, really, really, disliked this movie. The characters were poorly developed, the plot was boring and predictable, and the dialogue was stilted and unrealistic. I couldn't wait for it to end...
[Output continues for 5 reviews]
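Before training on text, the raw strings must become numeric tensors. Here is a minimal sketch using the Keras TextVectorization layer; the vocabulary size and sequence length are arbitrary choices, not values taken from this dataset:
import tensorflow as tf
import tensorflow_datasets as tfds

train_text = tfds.load("imdb_reviews", split="train[:1000]", as_supervised=True)

# Learn a vocabulary from the raw review strings
vectorizer = tf.keras.layers.TextVectorization(
    max_tokens=10000, output_sequence_length=200)
vectorizer.adapt(train_text.map(lambda text, label: text).batch(64))

# Map batches of raw strings to padded integer sequences
vectorized_ds = train_text.batch(32).map(
    lambda text, label: (vectorizer(text), label))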
Advanced Features
Using as_supervised=True
For datasets with clear input-target pairs, you can use the as_supervised parameter:
import tensorflow_datasets as tfds
# Load the dataset with as_supervised=True
mnist_ds = tfds.load("mnist", split="train[:100]", as_supervised=True)
# Now the dataset yields (image, label) tuples directly
for image, label in mnist_ds.take(1):
    print(f"Image shape: {image.shape}")
    print(f"Label: {label.numpy()}")
Output:
Image shape: (28, 28, 1)
Label: 7
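Two related helpers are worth knowing: passing batch_size=-1 to tfds.load returns the entire split as a single batch, and tfds.as_numpy converts the result to plain NumPy arrays:
import tensorflow_datasets as tfds

# The whole test split as NumPy arrays in one call
images, labels = tfds.as_numpy(
    tfds.load("mnist", split="test", batch_size=-1, as_supervised=True))
print(images.shape, labels.shape)  # (10000, 28, 28, 1) (10000,)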
Data Augmentation
Let's apply simple data augmentation to images:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
# Load a few CIFAR images
cifar_ds = tfds.load("cifar10", split="train[:4]", as_supervised=True)
# Define an augmentation function
def augment(image, label):
    # Random flip left-right
    image = tf.image.random_flip_left_right(image)
    # Random brightness adjustment
    image = tf.image.random_brightness(image, 0.2)
    # Random contrast adjustment
    image = tf.image.random_contrast(image, 0.8, 1.2)
    return image, label
# Create augmented dataset
augmented_ds = cifar_ds.map(augment)
# Display original and augmented images
plt.figure(figsize=(10, 5))
for i, ((orig_img, label), (aug_img, _)) in enumerate(zip(cifar_ds, augmented_ds)):
    # Original image
    plt.subplot(2, 4, i+1)
    plt.imshow(orig_img.numpy())
    plt.title(f"Original {i+1}")
    plt.axis('off')
    # Augmented image
    plt.subplot(2, 4, i+5)
    plt.imshow(tf.clip_by_value(aug_img, 0, 255).numpy().astype('uint8'))
    plt.title(f"Augmented {i+1}")
    plt.axis('off')
plt.tight_layout()
plt.show()
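The same augmentations can also be expressed as Keras preprocessing layers, which can live either in the input pipeline or inside the model itself. A minimal sketch; the factors are arbitrary, and we cast to float to keep dtypes consistent downstream:
import tensorflow as tf

augmenter = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomContrast(0.2),
])

# training=True enables the random behaviour; at inference these layers are no-ops
augmented_ds = cifar_ds.batch(32).map(
    lambda x, y: (augmenter(tf.cast(x, tf.float32), training=True), y))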
Creating Custom Datasets
Sometimes you need to work with your own data. The examples below build pipelines with the lower-level tf.data API; TFDS itself also supports packaging your data as a proper dataset via its DatasetBuilder classes.
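For full TFDS integration (versioning, metadata, automatic downloads), you can subclass tfds.core.GeneratorBasedBuilder. The skeleton below is only a sketch: the features, paths, and labeling logic are placeholders you would replace with your own.
import pathlib
import tensorflow_datasets as tfds

class MyDataset(tfds.core.GeneratorBasedBuilder):
    """Sketch of a custom TFDS builder; all specifics are placeholders."""
    VERSION = tfds.core.Version("1.0.0")

    def _info(self):
        return tfds.core.DatasetInfo(
            builder=self,
            features=tfds.features.FeaturesDict({
                "image": tfds.features.Image(shape=(None, None, 3)),
                "label": tfds.features.ClassLabel(names=["cat", "dog"]),
            }),
        )

    def _split_generators(self, dl_manager):
        # Normally you would download/extract via dl_manager here;
        # this sketch assumes a local directory of images
        return {"train": self._generate_examples("path/to/train_images")}

    def _generate_examples(self, path):
        for img_path in pathlib.Path(path).glob("*.jpg"):
            # Each example needs a key that is unique within the split
            yield img_path.name, {"image": str(img_path), "label": "cat"}  # placeholder label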
Using tf.data.Dataset.from_tensor_slices()
For small datasets that fit in memory:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Create synthetic data; x is reshaped to a column vector so each
# example has shape (1,), matching the model's input_shape below
x = np.linspace(-2, 2, 200).reshape(-1, 1)
y = x**2 + 0.1 * np.random.randn(200, 1)
# Create a TensorFlow dataset
dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.batch(32)
# Plot the data
plt.figure(figsize=(8, 6))
for batch_x, batch_y in dataset.take(1):
    plt.scatter(batch_x, batch_y, alpha=0.6)
plt.title("Custom Dataset Example: Quadratic Function with Noise")
plt.xlabel("x")
plt.ylabel("y")
plt.grid(True)
plt.show()
# Use the dataset for training
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(1,)),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')
model.fit(dataset, epochs=5)
Creating a Dataset from Files
For datasets stored in files:
import tensorflow as tf
import pathlib
import numpy as np
import matplotlib.pyplot as plt
# Create some sample CSV files
data_dir = pathlib.Path("sample_data")
data_dir.mkdir(exist_ok=True)
# Generate 3 CSV files with random data
for i in range(3):
    x = np.random.normal(i, 1, 100)
    y = 2 * x + i + 0.1 * np.random.randn(100)
    data = np.column_stack([x, y])
    np.savetxt(data_dir / f"data_{i}.csv", data, delimiter=",")
# Function to parse CSV files
def parse_csv(line):
    fields = tf.io.decode_csv(line, record_defaults=[[0.0], [0.0]])
    x = fields[0]
    y = fields[1]
    return x, y
# Create a dataset from the CSV files
file_pattern = str(data_dir / "*.csv")
file_dataset = tf.data.Dataset.list_files(file_pattern)
# Process each file
dataset = file_dataset.interleave(
    lambda filepath: tf.data.TextLineDataset(filepath).map(parse_csv),
    cycle_length=3
)
# Batch the dataset
dataset = dataset.batch(32)
# Display a batch of data
plt.figure(figsize=(8, 6))
for x_batch, y_batch in dataset.take(1):
    plt.scatter(x_batch, y_batch, alpha=0.6)
plt.title("Dataset from CSV Files")
plt.xlabel("x")
plt.ylabel("y")
plt.grid(True)
plt.show()
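For CSV data specifically, tf.data also ships a higher-level helper, tf.data.experimental.make_csv_dataset, which combines parsing, shuffling, and batching in one call. A sketch using the files written above (note that our generated files have no header row):
import tensorflow as tf

csv_ds = tf.data.experimental.make_csv_dataset(
    file_pattern="sample_data/*.csv",
    batch_size=32,
    column_names=["x", "y"],  # our files have no header row
    header=False,
    num_epochs=1,
)
# Each element is a dict mapping column names to batched tensors
for batch in csv_ds.take(1):
    print(batch["x"].shape, batch["y"].shape)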
Real-World Example: Image Classification
Let's create a complete example of using TFDS for image classification:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
# Load and prepare the CIFAR-10 dataset
(train_ds, test_ds), ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
# Get class names
class_names = ds_info.features['label'].names
def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255.0, label
# Prepare training dataset
train_ds = train_ds.map(normalize_img)
train_ds = train_ds.cache()
train_ds = train_ds.shuffle(ds_info.splits['train'].num_examples)
train_ds = train_ds.batch(128)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
# Prepare test dataset
test_ds = test_ds.map(normalize_img)
test_ds = test_ds.batch(128)
test_ds = test_ds.cache()
test_ds = test_ds.prefetch(tf.data.AUTOTUNE)
# Create a model
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10)
])
# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
# Train the model
history = model.fit(
    train_ds,
    epochs=5,
    validation_data=test_ds,
)
# Plot the results
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history.history['sparse_categorical_accuracy'], label='Training Accuracy')
plt.plot(history.history['val_sparse_categorical_accuracy'], label='Validation Accuracy')
plt.title('Accuracies')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.tight_layout()
plt.show()
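As a final check, model.evaluate reports loss and accuracy on the held-out test set:
# Evaluate on the test set
test_loss, test_acc = model.evaluate(test_ds)
print(f"Test accuracy: {test_acc:.3f}")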
Summary
The TensorFlow Datasets (TFDS) library is a powerful tool that simplifies the process of accessing and working with a wide variety of datasets for machine learning projects. In this guide, we explored:
- How to load and explore pre-built datasets
- Understanding dataset structure and metadata
- Preprocessing and transforming data for training
- Creating data pipelines with batching, shuffling, and prefetching
- Data augmentation techniques
- Creating custom datasets
- Building a complete image classification pipeline
By leveraging TFDS, you can focus more on building and refining models rather than spending time on data preparation and management.
Additional Resources
- TensorFlow Datasets Official Documentation: https://www.tensorflow.org/datasets
- TensorFlow Data Loading Guide: https://www.tensorflow.org/guide/data
- Catalog of available datasets in TFDS: https://www.tensorflow.org/datasets/catalog/overview
Exercises
- Load the "fashion_mnist" dataset and build a simple classifier to distinguish between different clothing items.
- Create a data augmentation pipeline for the CIFAR-10 dataset and compare the performance of a model trained with and without augmentation.
- Load a text dataset like "imdb_reviews" and build a sentiment analysis model.
- Create a custom dataset from a collection of your own images and use it to train a model.
- Explore the with_info parameter in tfds.load() and extract more metadata from a dataset of your choice.