
TensorFlow TFRecord

Introduction

When working with large datasets in machine learning applications, efficient data storage and retrieval become crucial for performance. TensorFlow's TFRecord format addresses this with a simple binary file format for efficiently storing and reading a sequence of binary records.

TFRecord offers several advantages:

  • Efficient storage: Data is stored in a compact binary format, which is often smaller than equivalent text formats such as CSV or JSON
  • Improved I/O performance: Reading binary records is typically faster than parsing text formats
  • Atomic records: Each example is stored as a self-contained record, making it easy to shuffle and batch data
  • Cross-platform compatibility: Works consistently across different operating systems and environments

In this tutorial, we'll learn how to create, write, and read TFRecord files in TensorFlow, and understand when and why you should use this format.

Understanding TFRecord Format

TFRecord is TensorFlow's binary file format designed specifically for storing a sequence of binary records. Each record in a TFRecord file contains:

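  1. Length: The size of the data in bytes, stored as a 64-bit unsigned integer
  2. A masked CRC-32C checksum of the length, stored as a 32-bit unsigned integer (for data integrity)
  3. Data: The actual serialized data (typically a serialized tf.train.Example protocol buffer)
  4. A masked CRC-32C checksum of the data, stored as a 32-bit unsigned integer (for data integrity)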
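To make the layout concrete, here is a minimal pure-Python sketch that frames and unframes records by hand. It assumes little-endian integers and the masking step TensorFlow documents for these checksums (rotate the CRC-32C right by 15 bits, then add 0xa282ead8). The helper names are hypothetical, and the code is only for illustration. Real pipelines should use tf.io.TFRecordWriter and tf.data.TFRecordDataset.

```python
import io
import struct

def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF

def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, keeping 32 bits."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF

def write_record(stream, data: bytes) -> None:
    length = struct.pack('<Q', len(data))                # uint64 length
    stream.write(length)
    stream.write(struct.pack('<I', masked_crc(length)))  # uint32 checksum of the length bytes
    stream.write(data)                                   # the payload itself
    stream.write(struct.pack('<I', masked_crc(data)))    # uint32 checksum of the payload

def read_records(stream):
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of the stream
        (length,) = struct.unpack('<Q', header)
        (length_crc,) = struct.unpack('<I', stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError('corrupted length field')
        data = stream.read(length)
        (data_crc,) = struct.unpack('<I', stream.read(4))
        if data_crc != masked_crc(data):
            raise ValueError('corrupted record payload')
        yield data

buffer = io.BytesIO()
for payload in [b'first record', b'second']:
    write_record(buffer, payload)
buffer.seek(0)
print(list(read_records(buffer)))  # [b'first record', b'second']
```

Each record therefore costs 16 bytes of framing on top of its payload. The checksums let a reader detect corruption instead of silently returning bad data.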

Creating TFRecord Files

Let's start by creating a simple TFRecord file with some basic data.

Step 1: Import necessary libraries

```python
import tensorflow as tf
import numpy as np
```

Step 2: Create example data

Let's create some simple data that we'll store in our TFRecord file:

```python
# Sample data: 3 examples with features
feature_data = [
    {
        'feature1': 10.5,
        'feature2': [1, 2, 3],
        'feature3': 'Hello'
    },
    {
        'feature1': -5.0,
        'feature2': [4, 5, 6],
        'feature3': 'World'
    },
    {
        'feature1': 7.2,
        'feature2': [7, 8, 9],
        'feature3': 'TFRecord'
    }
]
```

Step 3: Convert data to TensorFlow Examples

TFRecord files typically store data as serialized tf.train.Example protocol buffers. We need to convert our Python data to this format:

```python
def _float_feature(value):
    """Returns a float_list from a single float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a list of bool / enum / int / uint values."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _bytes_feature(value):
    """Returns a bytes_list from a single string / byte string."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    if isinstance(value, str):
        value = value.encode('utf-8')  # BytesList stores bytes, so encode Python strings.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def create_example(data_point):
    # Map each field of the dictionary to the matching Feature type
    feature = {
        'feature1': _float_feature(data_point['feature1']),
        'feature2': _int64_feature(data_point['feature2']),
        'feature3': _bytes_feature(data_point['feature3'])
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
```
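To check what an Example actually contains, you can serialize it and parse the bytes back with the protocol buffer method ParseFromString. The sketch builds one example inline, mirroring the third data point, so that it runs on its own. Note that FloatList holds 32-bit floats, so a Python float like 7.2 is stored with float32 precision rather than exactly.

```python
import tensorflow as tf

# One Example with the same layout create_example() produces (built inline to be self-contained)
example = tf.train.Example(features=tf.train.Features(feature={
    'feature1': tf.train.Feature(float_list=tf.train.FloatList(value=[7.2])),
    'feature2': tf.train.Feature(int64_list=tf.train.Int64List(value=[7, 8, 9])),
    'feature3': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'TFRecord'])),
}))

# These are the bytes that TFRecordWriter.write() receives
serialized = example.SerializeToString()

# Parse the bytes back into a message and look at each feature's value list
restored = tf.train.Example()
restored.ParseFromString(serialized)
features = restored.features.feature

print(list(features['feature2'].int64_list.value))  # [7, 8, 9]
print(features['feature3'].bytes_list.value[0])     # b'TFRecord'
print(features['feature1'].float_list.value[0])     # close to 7.2, held with float32 precision
```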

Step 4: Write data to a TFRecord file

Now, let's write our examples to a TFRecord file:

```python
filename = 'example.tfrecord'
with tf.io.TFRecordWriter(filename) as writer:
    for data_point in feature_data:
        example = create_example(data_point)
        writer.write(example.SerializeToString())

print(f"TFRecord file saved at: {filename}")
```

Output:

TFRecord file saved at: example.tfrecord

Reading TFRecord Files

Now that we've created a TFRecord file, let's see how to read the data back.

Step 1: Create a feature description

When reading TFRecord files, we need to specify the expected format of each feature:

```python
feature_description = {
    'feature1': tf.io.FixedLenFeature([], tf.float32),
    'feature2': tf.io.FixedLenFeature([3], tf.int64),
    'feature3': tf.io.FixedLenFeature([], tf.string)
}
```
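FixedLenFeature([3], tf.int64) assumes every example has exactly three values in feature2, and parsing raises an error if one does not. If list lengths vary, tf.io.VarLenFeature parses the list into a tf.SparseTensor instead. FixedLenFeature also accepts a default_value for features that may be missing. The following is a self-contained sketch; make_example is an illustrative helper and not part of the tutorial's code.

```python
import tensorflow as tf

def make_example(values):
    # Serialize one Example whose 'feature2' list can have any length (and no 'feature1')
    feature = {'feature2': tf.train.Feature(int64_list=tf.train.Int64List(value=values))}
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

varlen_description = {
    # Variable-length lists are parsed into a tf.SparseTensor
    'feature2': tf.io.VarLenFeature(tf.int64),
    # A scalar that may be missing; default_value is used when it is absent
    'feature1': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}

results = []
for serialized in [make_example([1, 2]), make_example([3, 4, 5, 6])]:
    parsed = tf.io.parse_single_example(serialized, varlen_description)
    dense = tf.sparse.to_dense(parsed['feature2'])  # SparseTensor -> dense 1-D tensor
    results.append((dense.numpy().tolist(), float(parsed['feature1'].numpy())))

print(results)  # [([1, 2], 0.0), ([3, 4, 5, 6], 0.0)]
```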

Step 2: Parse the examples

We need a function to parse each serialized example:

```python
def _parse_function(example_proto):
    # Parse the input tf.train.Example proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, feature_description)
```

Step 3: Read the TFRecord file

Now, let's create a dataset from our TFRecord file and parse each example:

```python
raw_dataset = tf.data.TFRecordDataset(filename)
parsed_dataset = raw_dataset.map(_parse_function)

# Print the parsed data
for parsed_record in parsed_dataset:
    print("Feature1:", parsed_record['feature1'].numpy())
    print("Feature2:", parsed_record['feature2'].numpy())
    print("Feature3:", parsed_record['feature3'].numpy().decode('utf-8'))
    print("---")
```

Output:

Feature1: 10.5
Feature2: [1 2 3]
Feature3: Hello
---
Feature1: -5.0
Feature2: [4 5 6]
Feature3: World
---
Feature1: 7.2
Feature2: [7 8 9]
Feature3: TFRecord
---
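The parsing function used here calls tf.io.parse_single_example once per record. For larger files, a common optimization is to batch the serialized records first and parse each batch with tf.io.parse_example, which takes a 1-D tensor of serialized examples. The sketch below writes its own five-record file to a temporary directory so that it can run independently of example.tfrecord. serialize is a hypothetical helper.

```python
import os
import tempfile
import tensorflow as tf

batch_description = {
    'feature1': tf.io.FixedLenFeature([], tf.float32),
    'feature2': tf.io.FixedLenFeature([3], tf.int64),
}

def serialize(f1, f2):
    # Hypothetical helper: build and serialize one Example with two features
    feature = {
        'feature1': tf.train.Feature(float_list=tf.train.FloatList(value=[f1])),
        'feature2': tf.train.Feature(int64_list=tf.train.Int64List(value=f2)),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

path = os.path.join(tempfile.mkdtemp(), 'batched.tfrecord')
with tf.io.TFRecordWriter(path) as writer:
    for i in range(5):
        writer.write(serialize(float(i), [i, i + 1, i + 2]))

# Batch the serialized strings first, then parse each whole batch in one call
dataset = (tf.data.TFRecordDataset(path)
           .batch(2)
           .map(lambda batch: tf.io.parse_example(batch, batch_description)))

for batch in dataset:
    print(batch['feature1'].numpy(), batch['feature2'].shape)
```

Each parsed batch is a dictionary of tensors with a leading batch dimension, so feature2 arrives with shape (2, 3) for full batches.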

Handling Different Data Types in TFRecords

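TFRecord files can store various types of data. Each tf.train.Feature holds one of three list types (BytesList, FloatList, or Int64List), so other data must be converted to one of them. Here's how to handle different data types: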

Numeric Features

For numeric data, use tf.train.FloatList or tf.train.Int64List:

```python
# For floating point values
float_feature = tf.train.Feature(float_list=tf.train.FloatList(value=[3.14, 1.618]))

# For integer values
int_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[10, 20, 30]))
```

String Features

For text data, use tf.train.BytesList. Python strings must be encoded to bytes first (for example with str.encode('utf-8')):

```python
text_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'hello', b'world']))
```

Image Data

Images can be stored as encoded image bytes (for example, JPEG):

```python
import cv2

# Read an image (cv2.imread returns None if the file can't be read)
image = cv2.imread('image.jpg')
# Convert image to bytes
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()

# Create a bytes_feature from image bytes
image_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
```
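Decoding and re-encoding with OpenCV is not strictly required. If the file is already a JPEG, its original bytes can go straight into a BytesList, which also avoids a second lossy compression pass. When the record is read back, tf.io.decode_jpeg turns the bytes into an RGB uint8 tensor. In this sketch, a tiny synthetic image encoded with tf.io.encode_jpeg stands in for a real file so that it runs on its own.

```python
import tensorflow as tf

# A tiny synthetic 4x4 RGB image stands in for a real photo (assumption for the demo)
pixels = tf.zeros([4, 4, 3], dtype=tf.uint8)
jpeg_bytes = tf.io.encode_jpeg(pixels).numpy()  # Python bytes, like a .jpg file opened in 'rb' mode

# With a real file you could instead use: jpeg_bytes = open('image.jpg', 'rb').read()
example = tf.train.Example(features=tf.train.Features(feature={
    'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
}))

# Parse the record and decode the stored JPEG back into pixels
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    {'image_raw': tf.io.FixedLenFeature([], tf.string)})
decoded = tf.io.decode_jpeg(parsed['image_raw'], channels=3)
print(decoded.shape)  # (4, 4, 3)
```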

Real-World Application: Image Dataset

Let's look at a more practical example: storing an image dataset in TFRecord format.

Step 1: Prepare image data

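```python
import os
import glob
import cv2
import numpy as np
import tensorflow as tf

# Let's assume we have images and labels in these lists.
# sorted() makes the order deterministic, since glob returns files in arbitrary order.
image_paths = sorted(glob.glob('images/*.jpg'))
labels = [0, 1, 0, 1, 2]  # Corresponding labels, one per sorted image path
```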
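Hard-coded labels only line up if you know the exact file order. An alternative is to derive labels from directory names. This assumes a hypothetical layout in which each class has its own subdirectory (images/<class_name>/*.jpg). collect_labeled_paths is an illustrative helper and is not part of any library. The usage part creates empty placeholder files in a temporary directory just to show the result.

```python
import glob
import os
import tempfile

def collect_labeled_paths(root):
    """Return (paths, labels, class_names) for a root/<class_name>/*.jpg layout."""
    # Sorted class directory names define the label indices 0, 1, 2, ...
    class_names = sorted(
        name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name)))
    paths, labels = [], []
    for label, class_name in enumerate(class_names):
        for path in sorted(glob.glob(os.path.join(root, class_name, '*.jpg'))):
            paths.append(path)
            labels.append(label)
    return paths, labels, class_names

# Usage with placeholder files (hypothetical class names 'cat' and 'dog')
root = tempfile.mkdtemp()
for class_name, files in {'cat': ['a.jpg', 'b.jpg'], 'dog': ['c.jpg']}.items():
    os.makedirs(os.path.join(root, class_name))
    for name in files:
        open(os.path.join(root, class_name, name), 'wb').close()

paths, labels, class_names = collect_labeled_paths(root)
print(class_names, labels)  # ['cat', 'dog'] [0, 0, 1]
```

The returned paths and labels could take the place of image_paths and labels above.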

Step 2: Create a function to convert images and labels to TFRecord

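```python
def image_example(image_path, label):
    # Read image (OpenCV loads it in BGR channel order; returns None if unreadable)
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Encode image. cv2.imencode expects BGR input, so encode the image as loaded;
    # converting to RGB first would swap the red and blue channels in the stored JPEG.
    # tf.io.decode_jpeg returns RGB when the record is read back.
    _, encoded_image = cv2.imencode('.jpg', image)

    feature = {
        'height': _int64_feature([image.shape[0]]),
        'width': _int64_feature([image.shape[1]]),
        'channels': _int64_feature([image.shape[2]]),
        'label': _int64_feature([label]),
        # Store the encoded JPEG bytes directly in a BytesList
        'image_raw': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded_image.tobytes()]))
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))
```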

Step 3: Write images to TFRecord file

```python
# Create TFRecord file
record_file = 'images.tfrecord'
with tf.io.TFRecordWriter(record_file) as writer:
    # zip() pairs each image with its label and stops at the shorter list
    for image_path, label in zip(image_paths, labels):
        tf_example = image_example(image_path, label)
        writer.write(tf_example.SerializeToString())

print(f"TFRecord file with images saved at: {record_file}")
```

Step 4: Read and display the images from TFRecord

```python
import matplotlib.pyplot as plt

def parse_image_function(example_proto):
    # Create a description of the features
    feature_description = {
        'height': tf.io.FixedLenFeature([1], tf.int64),
        'width': tf.io.FixedLenFeature([1], tf.int64),
        'channels': tf.io.FixedLenFeature([1], tf.int64),
        'label': tf.io.FixedLenFeature([1], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
    }

    # Parse the example
    parsed_features = tf.io.parse_single_example(example_proto, feature_description)

    # Decode the image
    image = tf.io.decode_jpeg(parsed_features['image_raw'])

    # Set the shape of the image
    height = tf.cast(parsed_features['height'][0], tf.int32)
    width = tf.cast(parsed_features['width'][0], tf.int32)
    channels = tf.cast(parsed_features['channels'][0], tf.int32)
    image = tf.reshape(image, [height, width, channels])

    # Get the label
    label = parsed_features['label'][0]

    return image, label

# Create dataset from TFRecord file
image_dataset = tf.data.TFRecordDataset(record_file)
parsed_image_dataset = image_dataset.map(parse_image_function)

# Let's look at our data
count = 0
for image, label in parsed_image_dataset:
    print(f"Image shape: {image.shape}, Label: {label.numpy()}")
    count += 1

    # Display first 2 images
    if count <= 2:
        plt.figure(figsize=(6, 6))
        plt.imshow(image.numpy())
        plt.title(f"Label: {label.numpy()}")
        plt.axis('off')
        plt.show()
```
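Because the decoded images keep their original sizes, calling batch() on this dataset fails unless every image has the same shape. One common approach is to resize, and optionally rescale, inside the map function. The sketch below builds synthetic JPEGs of different sizes in memory as a stand-in for the image_raw feature. to_training_pair and the 64x64 target size are illustrative choices, not part of the tutorial.

```python
import tensorflow as tf

# Synthetic JPEGs of different sizes stand in for the 'image_raw' strings
image_sizes = [(48, 64), (100, 80), (32, 32)]
jpegs = [tf.io.encode_jpeg(tf.zeros([h, w, 3], tf.uint8)).numpy() for h, w in image_sizes]
demo_labels = [0, 1, 2]

IMAGE_SIZE = 64  # arbitrary target size for this sketch

def to_training_pair(jpeg, label):
    image = tf.io.decode_jpeg(jpeg, channels=3)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])  # returns float32
    image = image / 255.0  # scale pixel values to [0, 1]
    return image, label

dataset = (tf.data.Dataset.from_tensor_slices((jpegs, demo_labels))
           .map(to_training_pair, num_parallel_calls=tf.data.AUTOTUNE)
           .batch(2))

for images, batch_labels in dataset:
    print(images.shape, batch_labels.numpy())
```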

TFRecord with Compression

TFRecords can be compressed to save disk space. Here's how to write and read compressed TFRecord files:

Writing compressed TFRecord

```python
# Write compressed TFRecord
filename = 'compressed.tfrecord'
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(filename, options) as writer:
    for data_point in feature_data:
        example = create_example(data_point)
        writer.write(example.SerializeToString())

print(f"Compressed TFRecord file saved at: {filename}")
```

Reading compressed TFRecord

```python
# Read compressed TFRecord
raw_dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
parsed_dataset = raw_dataset.map(_parse_function)

for parsed_record in parsed_dataset:
    print("Feature1:", parsed_record['feature1'].numpy())
```
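TFRecordOptions also accepts "ZLIB", and an empty string means no compression. Whichever type you choose, the same compression_type has to be passed when reading. The self-contained sketch below writes the same 100 placeholder payloads once per setting into a temporary directory, records the file sizes, and reads each file back with the matching type.

```python
import os
import tempfile
import tensorflow as tf

payloads = [b'record-%d' % i for i in range(100)]
out_dir = tempfile.mkdtemp()

file_sizes = {}
roundtrip = {}
for compression in ['', 'GZIP', 'ZLIB']:  # '' means no compression
    key = compression or 'none'
    path = os.path.join(out_dir, f'data-{key}.tfrecord')
    options = tf.io.TFRecordOptions(compression_type=compression)
    with tf.io.TFRecordWriter(path, options) as writer:
        for payload in payloads:
            writer.write(payload)
    file_sizes[key] = os.path.getsize(path)
    # The reader must be told the same compression type the writer used
    dataset = tf.data.TFRecordDataset(path, compression_type=compression)
    roundtrip[key] = [record.numpy() for record in dataset]

print(file_sizes)
```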

Sharded TFRecord Files

For very large datasets, it's often beneficial to split the data into multiple "shards" (separate TFRecord files). This allows for:

  1. Faster parallel reading
  2. More efficient distributed training
  3. Better fault tolerance

Here's how to create sharded TFRecord files:

```python
num_shards = 3
data_points_per_shard = len(feature_data) // num_shards

for i in range(num_shards):
    filename = f'data-{i:05d}-of-{num_shards:05d}.tfrecord'
    start_idx = i * data_points_per_shard
    # The last shard also takes any leftover examples
    end_idx = start_idx + data_points_per_shard if i < num_shards - 1 else len(feature_data)

    with tf.io.TFRecordWriter(filename) as writer:
        for j in range(start_idx, end_idx):
            example = create_example(feature_data[j])
            writer.write(example.SerializeToString())

    print(f"Shard {i+1}/{num_shards} saved at: {filename}")
```
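With integer division, any remainder goes to the last shard. When there are fewer examples than shards, every shard except the last one ends up empty. A round-robin split sends example j to shard j % num_shards, which keeps shard sizes within one example of each other. Below is a minimal pure-Python sketch. round_robin_shards and shard_filename are hypothetical helpers.

```python
def round_robin_shards(items, num_shards):
    """Split items into num_shards lists; item j goes to shard j % num_shards."""
    return [items[i::num_shards] for i in range(num_shards)]

def shard_filename(i, num_shards, prefix='data'):
    # Same naming scheme as above, e.g. data-00000-of-00003.tfrecord
    return f'{prefix}-{i:05d}-of-{num_shards:05d}.tfrecord'

shards = round_robin_shards(list(range(10)), 3)
for i, shard in enumerate(shards):
    print(shard_filename(i, 3), shard)
# data-00000-of-00003.tfrecord [0, 3, 6, 9]
# data-00001-of-00003.tfrecord [1, 4, 7]
# data-00002-of-00003.tfrecord [2, 5, 8]
```

Each list in shards could then be written with its own TFRecordWriter, as in the loop above.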

To read all sharded files:

```python
# Create a pattern that matches all shards
file_pattern = 'data-*-of-*.tfrecord'
dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
parsed_dataset = dataset.map(_parse_function)

print("Reading from all shards:")
for parsed_record in parsed_dataset:
    print("Feature1:", parsed_record['feature1'].numpy())
```
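Passing a list of file names to TFRecordDataset reads the shards one after another, in the order given. For training, a common alternative is to shuffle the file names with tf.data.Dataset.list_files and then read several shards concurrently with interleave. This sketch writes three tiny shards to a temporary directory so that it runs on its own. cycle_length=2 is an arbitrary choice.

```python
import os
import tempfile
import tensorflow as tf

shard_dir = tempfile.mkdtemp()
num_shards = 3
for i in range(num_shards):
    path = os.path.join(shard_dir, f'data-{i:05d}-of-{num_shards:05d}.tfrecord')
    with tf.io.TFRecordWriter(path) as writer:
        for j in range(4):
            writer.write(f'shard{i}-record{j}'.encode('utf-8'))

# Shuffle the order of the shard files
files = tf.data.Dataset.list_files(os.path.join(shard_dir, 'data-*-of-*.tfrecord'), shuffle=True)

# Read records from up to 2 shards at a time, interleaving their records
dataset = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=2,
    num_parallel_calls=tf.data.AUTOTUNE)

records = [record.numpy().decode('utf-8') for record in dataset]
print(len(records))  # 12
```

The exact record order depends on the shuffled file order, but every record from every shard is read once per pass.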

Performance Tips for TFRecord

To maximize the performance benefits of TFRecord:

  1. Use the TFRecord format for large datasets: For small datasets, the overhead may not be worth it.

  2. Batch your data: Use the batch() method on your dataset:

    ```python
    dataset = dataset.batch(32)
    ```

  3. Prefetch data: Keep the GPU fed by prefetching the next batch:

    ```python
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    ```

  4. Cache when possible: For smaller datasets that fit in memory:

    ```python
    dataset = dataset.cache()
    ```

  5. Parallelize data reading:

    ```python
    dataset = tf.data.TFRecordDataset(filenames,
                                      num_parallel_reads=tf.data.AUTOTUNE)
    ```
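Combining these tips, one reasonable ordering is: read, parse in parallel, cache, shuffle, batch, and prefetch. Caching after parsing means the parse work happens only once. Shuffling after the cache gives a fresh order every epoch. The sketch writes its own 100-record file to a temporary directory. The feature names x and y, the helper functions, and the buffer and batch sizes are placeholders, not tuned values.

```python
import os
import tempfile
import tensorflow as tf

pipeline_description = {
    'x': tf.io.FixedLenFeature([], tf.float32),
    'y': tf.io.FixedLenFeature([], tf.int64),
}

def serialize(x, y):
    # Hypothetical helper: one Example with a float feature and an int label
    feature = {
        'x': tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
        'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

path = os.path.join(tempfile.mkdtemp(), 'train.tfrecord')
with tf.io.TFRecordWriter(path) as writer:
    for i in range(100):
        writer.write(serialize(float(i), i % 2))

def parse(example_proto):
    parsed = tf.io.parse_single_example(example_proto, pipeline_description)
    return parsed['x'], parsed['y']

dataset = (
    tf.data.TFRecordDataset([path], num_parallel_reads=tf.data.AUTOTUNE)
    .map(parse, num_parallel_calls=tf.data.AUTOTUNE)  # parse records in parallel
    .cache()                    # filled on the first full pass, served from memory afterwards
    .shuffle(buffer_size=100)   # shuffle after cache so each epoch gets a new order
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)  # prepare the next batch while the current one is consumed
)

for x_batch, y_batch in dataset:
    print(x_batch.shape, y_batch.shape)  # (32,) (32,) for full batches, (4,) (4,) for the last
```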

Summary

In this tutorial, we've learned:

  • What TFRecord files are and why they're useful for efficient data storage
  • How to create and write TFRecord files with different types of data
  • How to read and parse TFRecord files
  • How to work with images in TFRecord format
  • Advanced techniques like compression and sharding
  • Performance optimization tips for working with TFRecords

TFRecord is a powerful format that can improve your TensorFlow data pipeline performance, especially for large datasets where data loading is the bottleneck. By storing your data in TFRecord format, you can make your training workflows more efficient and reduce bottlenecks related to data loading.


Exercises

  1. Create a TFRecord file containing numeric features for a regression task (like house price prediction).
  2. Convert an existing classification dataset (like MNIST or CIFAR-10) to TFRecord format.
  3. Write a TFRecord file with mixed data types (text, images, and numeric data) and then read it back.
  4. Implement a data pipeline that reads from sharded TFRecord files and applies data augmentation.
  5. Benchmark the performance difference between reading raw image files vs reading from TFRecord format.

