
TensorFlow Datasets

Introduction

TensorFlow Datasets (TFDS) is a collection of ready-to-use datasets for machine learning. It simplifies loading, preparing, and using data in your TensorFlow projects: rather than manually downloading and preprocessing datasets, you let TFDS handle those steps, freeing you to focus on building and training your models.

In this tutorial, you'll learn:

  • Installing and importing TensorFlow Datasets
  • Loading popular datasets from the TFDS catalog
  • Understanding dataset structure and features
  • Preparing datasets for training
  • Creating your own data pipelines with TFDS tools

Getting Started with TensorFlow Datasets

Installation

First, you need to install the tensorflow-datasets package:

bash
pip install tensorflow-datasets
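
To confirm the installation (and check which version you got), a one-liner like this works:

bash
python -c "import tensorflow_datasets as tfds; print(tfds.__version__)"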

Let's import the necessary libraries:

python
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

Loading Datasets from TFDS

TensorFlow Datasets provides access to hundreds of datasets across categories including images, text, audio, and video.
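
One quick way to browse the catalog programmatically is tfds.list_builders(), which returns the names of all registered datasets:

python
# List every dataset currently registered with TFDS
builders = tfds.list_builders()
print(len(builders))    # several hundred names
print(builders[:5])     # the first few entries, alphabetically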

Basic Dataset Loading

The simplest way to load a dataset is using the tfds.load() function:

python
# Load the MNIST dataset
mnist_dataset = tfds.load(name='mnist', split='train')

This returns a tf.data.Dataset object that you can iterate through:

python
# Explore the first few examples
for example in mnist_dataset.take(2):
    print(example)

Output:

{'image': <tf.Tensor: shape=(28, 28, 1), dtype=uint8, numpy=
array([[[0],
        [0],
        [0],
        ...,
        [0],
        [0],
        [0]],
       ...
       [[0],
        [0],
        [0],
        ...,
        [0],
        [0],
        [0]]], dtype=uint8)>, 'label': <tf.Tensor: shape=(), dtype=int64, numpy=5>}
...

Dataset Splits

Most datasets come with predefined splits like 'train', 'test', and sometimes 'validation':

python
# Load specific splits
train_ds = tfds.load('mnist', split='train')
test_ds = tfds.load('mnist', split='test')

# You can also define custom splits
train_ds = tfds.load('mnist', split='train[:80%]')
validation_ds = tfds.load('mnist', split='train[80%:]')
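
The split syntax goes beyond percentages: it also accepts absolute indices, and you can concatenate splits with '+'. A short sketch:

python
# Slicing by absolute index and concatenating splits
first_1000 = tfds.load('mnist', split='train[:1000]')  # first 1,000 examples
combined = tfds.load('mnist', split='train+test')      # both splits, back to back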

Loading with as_supervised Parameter

Many datasets contain features and labels. To get them as tuples, use the as_supervised parameter:

python
# Load as (image, label) tuples
mnist_dataset = tfds.load('mnist', split='train', as_supervised=True)

# Now each example is a tuple
for image, label in mnist_dataset.take(1):
print(f"Image shape: {image.shape}, Label: {label.numpy()}")

Output:

Image shape: (28, 28, 1), Label: 5

Understanding Dataset Structure

Dataset Info

TFDS provides detailed information about each dataset:

python
# Get dataset info
mnist_info = tfds.builder('mnist').info
print(f"Dataset name: {mnist_info.name}")
print(f"Number of training examples: {mnist_info.splits['train'].num_examples}")
print(f"Number of test examples: {mnist_info.splits['test'].num_examples}")
print(f"Features: {mnist_info.features}")

Output:

Dataset name: mnist
Number of training examples: 60000
Number of test examples: 10000
Features: FeaturesDict({
    'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})
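
Note that tfds.load() is a convenience wrapper around this same builder object. If you want explicit control over when the data is downloaded, the two-step builder API does the same job; here's a minimal sketch:

python
# Equivalent to tfds.load('mnist', split='train'), but step by step
builder = tfds.builder('mnist')
builder.download_and_prepare()                # downloads only on the first call
train_ds = builder.as_dataset(split='train')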

Visualizing Examples

Let's visualize some examples from an image dataset:

python
# Load the dataset along with its metadata
mnist_dataset, mnist_info = tfds.load('mnist', split='train', with_info=True)

# Create a function to plot images
def plot_examples(dataset, num_examples=4):
    plt.figure(figsize=(10, 10))

    for i, example in enumerate(dataset.take(num_examples)):
        image = example['image'].numpy()
        label = example['label'].numpy()

        plt.subplot(2, 2, i+1)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.title(f"Label: {label}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Plot 4 examples
plot_examples(mnist_dataset)
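
For image datasets, TFDS also ships a built-in helper that renders a similar labeled grid without any custom plotting code:

python
# tfds.show_examples() builds the grid from the dataset's info object
fig = tfds.show_examples(mnist_dataset, mnist_info)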

Preparing Datasets for Training

Data Preprocessing

Most models require preprocessing steps like normalization. Here's how to apply transformations:

python
def preprocess_mnist(image, label):
    # Normalize pixel values
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

# Apply preprocessing to the dataset
train_ds = tfds.load('mnist', split='train', as_supervised=True)
train_ds = train_ds.map(preprocess_mnist)

# Check the results
for image, label in train_ds.take(1):
print(f"Image min value: {tf.reduce_min(image)}, max value: {tf.reduce_max(image)}")

Output:

Image min value: 0.0, max value: 1.0

Batching and Shuffling

To prepare a dataset for training, we typically shuffle, batch, and prefetch data:

python
# Prepare a dataset for training
def prepare_for_training(ds, batch_size=32, shuffle_buffer_size=1000):
    # Shuffle the data
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)

    # Batch the data
    ds = ds.batch(batch_size)

    # Prefetch for performance
    ds = ds.prefetch(tf.data.AUTOTUNE)

    return ds

# Apply the preparation function
train_ds = tfds.load('mnist', split='train', as_supervised=True)
train_ds = train_ds.map(preprocess_mnist)
train_ds = prepare_for_training(train_ds)

# Check the batch structure
for images, labels in train_ds.take(1):
print(f"Batch shape: {images.shape}")
print(f"Labels shape: {labels.shape}")

Output:

Batch shape: (32, 28, 28, 1)
Labels shape: (32,)
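
One common addition to this pipeline is cache(): for a dataset that fits in memory, caching the preprocessed examples means the map() work runs only during the first epoch. A sketch of the same pipeline with caching:

python
# Cache decoded, preprocessed examples so later epochs skip the map() work
train_ds = tfds.load('mnist', split='train', as_supervised=True)
train_ds = (train_ds
            .map(preprocess_mnist)
            .cache()       # keep results in memory after the first pass
            .shuffle(1000)
            .batch(32)
            .prefetch(tf.data.AUTOTUNE))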

Real-world Examples

Example 1: Training a Simple CNN on MNIST

Let's put everything together and train a simple CNN on the MNIST dataset:

python
# Load and preprocess the dataset
train_ds = tfds.load('mnist', split='train', as_supervised=True)
test_ds = tfds.load('mnist', split='test', as_supervised=True)

def preprocess(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_ds = train_ds.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(preprocess).batch(32)

# Create a simple CNN model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model
history = model.fit(
    train_ds,
    epochs=5,
    validation_data=test_ds
)

# Evaluate the model
test_loss, test_accuracy = model.evaluate(test_ds)
print(f"Test accuracy: {test_accuracy:.4f}")

Output:

Epoch 1/5
1875/1875 [==============================] - 9s 4ms/step - loss: 0.1214 - accuracy: 0.9624 - val_loss: 0.0460 - val_accuracy: 0.9848
...
Epoch 5/5
1875/1875 [==============================] - 8s 4ms/step - loss: 0.0252 - accuracy: 0.9921 - val_loss: 0.0388 - val_accuracy: 0.9876
313/313 [==============================] - 1s 2ms/step - loss: 0.0388 - accuracy: 0.9876
Test accuracy: 0.9876

Example 2: Working with Image Classification Datasets

Let's see how to work with a more complex dataset like CIFAR-10:

python
# Load CIFAR-10 dataset
cifar10_ds, cifar10_info = tfds.load('cifar10', split='train', with_info=True)

# Display dataset info
print(f"Features: {cifar10_info.features}")
print(f"Number of classes: {cifar10_info.features['label'].num_classes}")
print(f"Class names: {cifar10_info.features['label'].names}")

# Display some examples
plt.figure(figsize=(12, 12))
for i, example in enumerate(cifar10_ds.take(9)):
    image = example['image'].numpy()
    label = example['label'].numpy()
    label_name = cifar10_info.features['label'].names[label]

    plt.subplot(3, 3, i+1)
    plt.imshow(image)
    plt.title(f"{label_name} ({label})")
    plt.axis('off')

plt.show()
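
CIFAR-10 is also a natural place to try data augmentation (see exercise 5 below). As a minimal sketch using tf.image, you might randomly flip and brighten training images inside the map() step:

python
# A simple augmentation map for the training split only
def augment(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.random_flip_left_right(image)             # horizontal flip
    image = tf.image.random_brightness(image, max_delta=0.1)   # small brightness jitter
    return image, label

cifar10_train = tfds.load('cifar10', split='train', as_supervised=True)
cifar10_train = cifar10_train.map(augment).batch(32).prefetch(tf.data.AUTOTUNE)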

Creating Custom Data Pipelines

Beyond the ready-made catalog, the same tf.data API that TFDS datasets are built on works just as well with your own data:

python
# Example: Creating a dataset from NumPy arrays
# Generate synthetic data
num_examples = 1000
features = np.random.normal(size=(num_examples, 10)).astype(np.float32)
labels = np.random.randint(0, 2, size=(num_examples,)).astype(np.int32)

# Create a tf.data.Dataset
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Apply transformations
dataset = dataset.shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)

# Check the dataset
for feature_batch, label_batch in dataset.take(1):
    print(f"Feature batch shape: {feature_batch.shape}")
    print(f"Label batch shape: {label_batch.shape}")
    print(f"First example features: {feature_batch[0]}")
    print(f"First example label: {label_batch[0]}")

Output:

Feature batch shape: (32, 10)
Label batch shape: (32,)
First example features: [ 0.97548646 -0.7535481 -0.37163007 0.03770206 0.6645261 0.15464392
-1.8412154 0.45117098 -1.4589593 -0.77147275]
First example label: 1
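
If you later need plain NumPy arrays back out of any of these pipelines, tfds.as_numpy() converts a tf.data.Dataset into an iterable of NumPy values:

python
# Convert a tf.data.Dataset into NumPy arrays
for feature_batch, label_batch in tfds.as_numpy(dataset.take(1)):
    print(type(feature_batch))  # <class 'numpy.ndarray'>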

Summary

TensorFlow Datasets provides a powerful and convenient way to work with datasets for machine learning. In this tutorial, you learned:

  • Loading and exploring datasets with TFDS
  • Understanding dataset structure and metadata
  • Preparing datasets for model training through preprocessing, batching, and shuffling
  • Working through real-world examples with image datasets
  • Building your own data pipelines

With TFDS, you can focus more on model building and less on data handling logistics, making your machine learning workflow more efficient and productive.

Exercises

  1. Load the Fashion MNIST dataset and train a model to classify clothing items.
  2. Use the CIFAR-100 dataset and explore its hierarchical class structure.
  3. Load a text dataset like IMDB reviews and prepare it for sentiment analysis.
  4. Create your own image dataset pipeline using folders of images and the tf.data.Dataset.list_files() method.
  5. Experiment with data augmentation techniques on an image dataset to improve model performance.

