
TensorFlow Datasets Library

Introduction

The TensorFlow Datasets (TFDS) library provides machine learning practitioners with a collection of ready-to-use datasets for building, training, and evaluating models. It eliminates the need to manually download, extract, and preprocess data, a process that is both time-consuming and error-prone. TFDS handles these steps automatically, letting you focus on model development rather than data preparation.

In this guide, we'll explore how to use TensorFlow Datasets to:

  • Load popular datasets with just a few lines of code
  • Explore and understand dataset structure
  • Preprocess and transform datasets for training
  • Create custom datasets when needed

Getting Started with TensorFlow Datasets

Installation

Before we begin, let's make sure TensorFlow Datasets is installed:

bash
pip install tensorflow tensorflow-datasets

Basic Usage

Let's start with a simple example of loading a dataset:

python
import tensorflow as tf
import tensorflow_datasets as tfds

# Load the MNIST dataset
mnist_dataset = tfds.load(name="mnist", split="train")

# Convert the dataset to a format suitable for training
mnist_dataset = mnist_dataset.map(lambda example: (example['image'], example['label']))
mnist_dataset = mnist_dataset.batch(32)

# Iterate through the dataset
for images, labels in mnist_dataset.take(1):
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")

Output:

Batch shape: (32, 28, 28, 1)
Labels shape: (32,)

In this example, we loaded the MNIST dataset, which contains 28x28 grayscale images of handwritten digits, and prepared it for training by batching the data.
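
Before mapping or batching, it's often useful to check what each element looks like. Every tf.data dataset exposes an element_spec attribute describing the names, shapes, and dtypes of its elements; here's a quick inspection on a freshly loaded copy:

python
import tensorflow_datasets as tfds

# Each raw TFDS element is a dictionary of features
raw_mnist = tfds.load(name="mnist", split="train")
print(raw_mnist.element_spec)
# {'image': TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, ...),
#  'label': TensorSpec(shape=(), dtype=tf.int64, ...)}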

Exploring Available Datasets

TFDS provides access to hundreds of datasets across domains such as computer vision, natural language processing, and audio processing.

Listing Available Datasets

You can see all available datasets with:

python
import tensorflow_datasets as tfds

# List all available datasets
all_datasets = tfds.list_builders()
print(f"Total number of datasets: {len(all_datasets)}")
print(f"First 10 datasets: {all_datasets[:10]}")

Output:

Total number of datasets: 287
First 10 datasets: ['abstract_reasoning', 'aeslc', 'aflw2k3d', 'air_quality', 'amazon_us_reviews', 'animated_knots', 'answer_equivalence', 'aqua', 'arc', 'bair_robot_pushing_small']
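
Since the full list is long, a simple substring filter (plain Python, not a TFDS API) helps when hunting for related datasets:

python
import tensorflow_datasets as tfds

# Find all registered builders whose name mentions "mnist"
mnist_like = [name for name in tfds.list_builders() if "mnist" in name]
print(mnist_like)  # e.g. ['binarized_mnist', 'emnist', 'fashion_mnist', 'mnist', ...]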

Getting Dataset Information

To get more information about a specific dataset:

python
import tensorflow_datasets as tfds

# Get info about the CIFAR-10 dataset
cifar_info = tfds.builder("cifar10").info
print(f"Name: {cifar_info.name}")
print(f"Description: {cifar_info.description[:150]}...")
print(f"Number of classes: {cifar_info.features['label'].num_classes}")
print(f"Class names: {cifar_info.features['label'].names}")
print(f"Training examples: {cifar_info.splits['train'].num_examples}")
print(f"Test examples: {cifar_info.splits['test'].num_examples}")

Output:

Name: cifar10
Description: The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test...
Number of classes: 10
Class names: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
Training examples: 50000
Test examples: 10000
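
Note that tfds.builder only reads metadata; it downloads nothing. If you need more control than tfds.load offers, the same builder object can also fetch and serve the data. A sketch of that workflow:

python
import tensorflow_datasets as tfds

builder = tfds.builder("cifar10")
builder.download_and_prepare()  # downloads and writes the dataset to disk (only once)
train_ds = builder.as_dataset(split="train")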

Loading and Preparing Datasets

Loading Different Splits

Datasets typically come with predefined splits like "train", "test", and sometimes "validation". You can load specific splits:

python
import tensorflow_datasets as tfds

# Load specific splits
train_ds = tfds.load("mnist", split="train[:80%]")
validation_ds = tfds.load("mnist", split="train[80%:90%]")
test_ds = tfds.load("mnist", split="train[90%:]")

print(f"Number of training batches: {len(list(train_ds))}")
print(f"Number of validation batches: {len(list(validation_ds))}")
print(f"Number of test batches: {len(list(test_ds))}")

Output:

Number of training examples: 48000
Number of validation examples: 6000
Number of test examples: 6000
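
The same slicing syntax works when requesting several splits in a single call, in which case tfds.load returns one dataset per requested split. A minimal sketch:

python
import tensorflow_datasets as tfds

# Request all three subsets at once; the result is a list of datasets
train_ds, validation_ds, test_ds = tfds.load(
    "mnist",
    split=["train[:80%]", "train[80%:90%]", "train[90%:]"],
)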

Data Preprocessing

Most datasets need preprocessing before training. Here's how to normalize image data and one-hot encode labels:

python
import tensorflow as tf
import tensorflow_datasets as tfds

def preprocess_image(example):
    # Normalize pixel values to [0, 1]
    image = tf.cast(example['image'], tf.float32) / 255.0

    # One-hot encode the label
    label = tf.one_hot(example['label'], depth=10)

    return image, label

# Load and preprocess MNIST
mnist_ds = tfds.load("mnist", split="train[:1000]")
mnist_ds = mnist_ds.map(preprocess_image)
mnist_ds = mnist_ds.batch(32).prefetch(tf.data.AUTOTUNE)

# Check the processed data
for images, labels in mnist_ds.take(1):
    print(f"Image batch shape: {images.shape}")
    print(f"Image value range: {tf.reduce_min(images).numpy()} to {tf.reduce_max(images).numpy()}")
    print(f"Label batch shape: {labels.shape}")
    print(f"First label (one-hot encoded): {labels[0]}")

Output:

Image batch shape: (32, 28, 28, 1)
Image value range: 0.0 to 1.0
Label batch shape: (32, 10)
First label (one-hot encoded): [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

Working with Different Types of Datasets

Image Datasets

Let's load and display images from the CIFAR-10 dataset:

python
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Load CIFAR-10
cifar10_ds = tfds.load("cifar10", split="train[:10]")

# Create a figure to display images
plt.figure(figsize=(10, 5))
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

for i, example in enumerate(cifar10_ds.take(10)):
    image = example["image"]
    label = example["label"].numpy()

    plt.subplot(2, 5, i + 1)
    plt.imshow(image)
    plt.title(class_names[label])
    plt.axis('off')

plt.tight_layout()
plt.show()

Text Datasets

TFDS also provides text datasets. Let's explore the IMDb movie reviews dataset:

python
import tensorflow_datasets as tfds

# Load IMDb dataset
imdb_ds = tfds.load("imdb_reviews", split="train[:5]")

# Print a few examples
for i, example in enumerate(imdb_ds.take(5)):
    text = example["text"].numpy().decode("utf-8")
    label = "Positive" if example["label"].numpy() == 1 else "Negative"

    print(f"\nReview {i+1} ({label}):")
    print(f"{text[:200]}...")

Output:

Review 1 (Positive):
This was an absolutely incredible movie. Don Cheadle delivers an Oscar worthy performance as the hotel manager of the Hotel Des Milles Collines. Nick Nolte and Joaquin Phoenix are also excellent in their roles as a UN colonel and...

Review 2 (Negative):
I really, really, really, disliked this movie. The characters were poorly developed, the plot was boring and predictable, and the dialogue was stilted and unrealistic. I couldn't wait for it to end...

[Output continues for 5 reviews]
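
Before raw strings like these can feed a model, they must be tokenized and mapped to integer ids. One common approach is Keras's TextVectorization layer; here is a minimal sketch that adapts a vocabulary on the five reviews loaded above (the max_tokens and output_sequence_length values are arbitrary choices for illustration):

python
import tensorflow as tf

# Learn a small vocabulary from the raw review text
vectorizer = tf.keras.layers.TextVectorization(
    max_tokens=1000, output_sequence_length=100)
vectorizer.adapt(imdb_ds.map(lambda example: example["text"]))

# Each review becomes a fixed-length sequence of token ids
for batch in imdb_ds.batch(2).take(1):
    tokens = vectorizer(batch["text"])
    print(tokens.shape)  # (2, 100)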

Advanced Features

Using as_supervised=True

For datasets with clear input-target pairs, you can use the as_supervised parameter:

python
import tensorflow_datasets as tfds

# Load the dataset with as_supervised=True
mnist_ds = tfds.load("mnist", split="train[:100]", as_supervised=True)

# Now the dataset yields (image, label) tuples directly
for image, label in mnist_ds.take(1):
    print(f"Image shape: {image.shape}")
    print(f"Label: {label.numpy()}")

Output:

Image shape: (28, 28, 1)
Label: 7
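
For comparison, with the default as_supervised=False the same dataset yields feature dictionaries, so the two loading modes differ only in element structure:

python
import tensorflow_datasets as tfds

# Default behaviour: elements are dictionaries keyed by feature name
mnist_dict_ds = tfds.load("mnist", split="train[:100]")
for example in mnist_dict_ds.take(1):
    print(example.keys())  # dict_keys(['image', 'label'])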

Data Augmentation

Let's apply simple data augmentation to images:

python
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Load a few CIFAR images
cifar_ds = tfds.load("cifar10", split="train[:4]", as_supervised=True)

# Define an augmentation function
def augment(image, label):
    # Random flip left-right
    image = tf.image.random_flip_left_right(image)
    # Random brightness adjustment
    image = tf.image.random_brightness(image, 0.2)
    # Random contrast adjustment
    image = tf.image.random_contrast(image, 0.8, 1.2)
    return image, label

# Create augmented dataset
augmented_ds = cifar_ds.map(augment)

# Display original and augmented images
plt.figure(figsize=(10, 5))
for i, ((orig_img, label), (aug_img, _)) in enumerate(zip(cifar_ds, augmented_ds)):
    # Original image
    plt.subplot(2, 4, i + 1)
    plt.imshow(orig_img.numpy())
    plt.title(f"Original {i+1}")
    plt.axis('off')

    # Augmented image
    plt.subplot(2, 4, i + 5)
    plt.imshow(tf.clip_by_value(aug_img, 0, 255).numpy().astype('uint8'))
    plt.title(f"Augmented {i+1}")
    plt.axis('off')

plt.tight_layout()
plt.show()
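
In an actual training pipeline you would normally apply augmentation only to the training split and run it in parallel; a sketch building on the augment function above:

python
# Augment on the fly, in parallel, as part of the input pipeline
train_ds = (
    tfds.load("cifar10", split="train", as_supervised=True)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)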

Creating Custom Datasets

Sometimes you need to work with your own data. The examples below use the underlying tf.data API directly, which is usually the quickest route for custom data; TFDS also offers a builder API for packaging datasets you want to reuse or share.

Using tf.data.Dataset.from_tensor_slices()

For small datasets that fit in memory:

python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Create synthetic data
x = np.linspace(-2, 2, 200).reshape(-1, 1).astype(np.float32)  # shape (200, 1) so Dense gets rank-2 input
y = (x**2 + 0.1 * np.random.randn(200, 1)).astype(np.float32)

# Create a TensorFlow dataset
dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.batch(32)

# Plot the data
plt.figure(figsize=(8, 6))
for batch_x, batch_y in dataset.take(1):
    plt.scatter(batch_x, batch_y, alpha=0.6)
plt.title("Custom Dataset Example: Quadratic Function with Noise")
plt.xlabel("x")
plt.ylabel("y")
plt.grid(True)
plt.show()

# Use the dataset for training
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(1,)),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1)
])

model.compile(optimizer='adam', loss='mse')
model.fit(dataset, epochs=5)
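
When the data does not fit in memory, tf.data.Dataset.from_generator streams examples from a plain Python generator instead; you declare the element types and shapes up front. A minimal sketch with synthetic data:

python
import numpy as np
import tensorflow as tf

def sample_generator():
    # Yield (x, y) pairs one at a time instead of holding everything in memory
    for _ in range(200):
        x = np.float32(np.random.uniform(-2, 2))
        yield x, np.float32(x**2)

streamed_ds = tf.data.Dataset.from_generator(
    sample_generator,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),
    ),
).batch(32)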

Creating a Dataset from Files

For datasets stored in files:

python
import tensorflow as tf
import pathlib
import numpy as np
import matplotlib.pyplot as plt

# Create some sample CSV files
data_dir = pathlib.Path("sample_data")
data_dir.mkdir(exist_ok=True)

# Generate 3 CSV files with random data
for i in range(3):
    x = np.random.normal(i, 1, 100)
    y = 2 * x + i + 0.1 * np.random.randn(100)
    data = np.column_stack([x, y])
    np.savetxt(data_dir / f"data_{i}.csv", data, delimiter=",")

# Function to parse CSV files
def parse_csv(line):
    fields = tf.io.decode_csv(line, record_defaults=[[0.0], [0.0]])
    x = fields[0]
    y = fields[1]
    return x, y

# Create a dataset from the CSV files
file_pattern = str(data_dir / "*.csv")
file_dataset = tf.data.Dataset.list_files(file_pattern)

# Process each file
dataset = file_dataset.interleave(
    lambda filepath: tf.data.TextLineDataset(filepath).map(parse_csv),
    cycle_length=3
)

# Batch the dataset
dataset = dataset.batch(32)

# Display a batch of data
plt.figure(figsize=(8, 6))
for x_batch, y_batch in dataset.take(1):
    plt.scatter(x_batch, y_batch, alpha=0.6)
plt.title("Dataset from CSV Files")
plt.xlabel("x")
plt.ylabel("y")
plt.grid(True)
plt.show()
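
TensorFlow also ships a higher-level CSV helper, tf.data.experimental.make_csv_dataset, which handles parsing, batching, and shuffling in one call. A sketch for the headerless files generated above (hence the explicit column_names and header=False):

python
csv_ds = tf.data.experimental.make_csv_dataset(
    str(data_dir / "*.csv"),
    batch_size=32,
    column_names=["x", "y"],  # our files have no header row
    header=False,
    label_name="y",
    num_epochs=1,  # the default repeats indefinitely
)

# Each element is a (features_dict, labels) pair
for features, y_batch in csv_ds.take(1):
    print(features["x"].shape, y_batch.shape)  # (32,) (32,)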

Real-World Example: Image Classification

Let's create a complete example of using TFDS for image classification:

python
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Load and prepare the CIFAR-10 dataset
(train_ds, test_ds), ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

# Get class names
class_names = ds_info.features['label'].names

def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255.0, label

# Prepare training dataset
train_ds = train_ds.map(normalize_img)
train_ds = train_ds.cache()
train_ds = train_ds.shuffle(ds_info.splits['train'].num_examples)
train_ds = train_ds.batch(128)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

# Prepare test dataset
test_ds = test_ds.map(normalize_img)
test_ds = test_ds.batch(128)
test_ds = test_ds.cache()
test_ds = test_ds.prefetch(tf.data.AUTOTUNE)

# Create a model
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10)
])

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Train the model
history = model.fit(
    train_ds,
    epochs=5,
    validation_data=test_ds,
)

# Plot the results
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['sparse_categorical_accuracy'], label='Training Accuracy')
plt.plot(history.history['val_sparse_categorical_accuracy'], label='Validation Accuracy')
plt.title('Accuracies')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()
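
Once training finishes, the prepared test pipeline can be reused for a final evaluation, for instance:

python
# Evaluate on the prepared test pipeline
test_loss, test_accuracy = model.evaluate(test_ds)
print(f"Test accuracy: {test_accuracy:.3f}")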

Summary

The TensorFlow Datasets (TFDS) library is a powerful tool that simplifies the process of accessing and working with a wide variety of datasets for machine learning projects. In this guide, we explored:

  • How to load and explore pre-built datasets
  • Understanding dataset structure and metadata
  • Preprocessing and transforming data for training
  • Creating data pipelines with batching, shuffling, and prefetching
  • Data augmentation techniques
  • Creating custom datasets
  • Building a complete image classification pipeline

By leveraging TFDS, you can focus more on building and refining models rather than spending time on data preparation and management.

Exercises

  1. Load the "fashion_mnist" dataset and build a simple classifier to distinguish between different clothing items.
  2. Create a data augmentation pipeline for the CIFAR-10 dataset and compare the performance of a model trained with and without augmentation.
  3. Load a text dataset like "imdb_reviews" and build a sentiment analysis model.
  4. Create a custom dataset from a collection of your own images and use it to train a model.
  5. Explore the with_info parameter in tfds.load() and extract more metadata from a dataset of your choice.

