TensorFlow TFRecord
Introduction
When working with large datasets in machine learning applications, efficient data storage and retrieval become crucial for performance. TensorFlow's TFRecord format addresses this by providing a simple binary file format for efficiently storing and reading a sequence of binary records.
TFRecord offers several advantages:
- Efficient storage: Data is stored in a binary format, which is more space-efficient than text formats
- Improved I/O performance: Reading binary data is faster than parsing text formats
- Atomic records: Each example is stored as a self-contained record, making it easy to shuffle and batch data
- Cross-platform compatibility: Works consistently across different operating systems and environments
In this tutorial, we'll learn how to create, write, and read TFRecord files in TensorFlow, and understand when and why you should use this format.
Understanding TFRecord Format
TFRecord is TensorFlow's binary file format designed specifically for storing a sequence of binary records. Each record in a TFRecord file contains:
- Length: The size of the data in bytes
- CRC-32C checksum of the length (for data integrity)
- Data: The actual serialized data (typically a serialized tf.train.Example protocol buffer)
- CRC-32C checksum of the data (for data integrity)
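To make that layout concrete, here is a minimal sketch that walks the raw framing by hand. It assumes the standard little-endian framing and skips checksum verification for brevity; in practice you would read files with tf.data.TFRecordDataset, as shown later in this tutorial.

import struct

def iterate_raw_records(path):
    """Yield the raw data payload of each record in a TFRecord file."""
    with open(path, 'rb') as f:
        while True:
            header = f.read(8)       # uint64: length of the data
            if not header:
                break                # end of file
            length, = struct.unpack('<Q', header)
            f.read(4)                # masked CRC-32C of the length (not verified here)
            data = f.read(length)    # the serialized record (e.g. a tf.train.Example)
            f.read(4)                # masked CRC-32C of the data (not verified here)
            yield data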
Creating TFRecord Files
Let's start by creating a simple TFRecord file with some basic data.
Step 1: Import necessary libraries
import tensorflow as tf
import numpy as np
Step 2: Create example data
Let's create some simple data that we'll store in our TFRecord file:
# Sample data: 3 examples with features
feature_data = [
    {
        'feature1': 10.5,
        'feature2': [1, 2, 3],
        'feature3': 'Hello'
    },
    {
        'feature1': -5.0,
        'feature2': [4, 5, 6],
        'feature3': 'World'
    },
    {
        'feature1': 7.2,
        'feature2': [7, 8, 9],
        'feature3': 'TFRecord'
    }
]
Step 3: Convert data to TensorFlow Examples
TFRecord files typically store data as serialized tf.train.Example
protocol buffers. We need to convert our Python data to this format:
def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a list of bools / enums / ints / uints."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _bytes_feature(value):
    """Returns a bytes_list from a string / bytes."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    if isinstance(value, str):
        value = value.encode('utf-8')  # encode str to bytes; bytes pass through unchanged
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def create_example(data_point):
    feature = {
        'feature1': _float_feature(data_point['feature1']),
        'feature2': _int64_feature(data_point['feature2']),
        'feature3': _bytes_feature(data_point['feature3'])
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
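Before writing anything to disk, it can help to sanity-check one converted example. The snippet below is just a quick inspection, not a required step: it prints the protobuf in its human-readable text form and shows the serialized size.

example = create_example(feature_data[0])
print(example)  # human-readable protobuf text
print(f"Serialized size: {len(example.SerializeToString())} bytes")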
Step 4: Write data to a TFRecord file
Now, let's write our examples to a TFRecord file:
filename = 'example.tfrecord'

with tf.io.TFRecordWriter(filename) as writer:
    for data_point in feature_data:
        example = create_example(data_point)
        writer.write(example.SerializeToString())

print(f"TFRecord file saved at: {filename}")
Output:
TFRecord file saved at: example.tfrecord
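As an optional sanity check, you can count the records that were just written by iterating over the file:

# Each element of a TFRecordDataset is one raw serialized record
num_records = sum(1 for _ in tf.data.TFRecordDataset(filename))
print(f"Records in {filename}: {num_records}")  # expected: 3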
Reading TFRecord Files
Now that we've created a TFRecord file, let's see how to read the data back.
Step 1: Create a feature description
When reading TFRecord files, we need to specify the expected format of each feature:
feature_description = {
    'feature1': tf.io.FixedLenFeature([], tf.float32),
    'feature2': tf.io.FixedLenFeature([3], tf.int64),
    'feature3': tf.io.FixedLenFeature([], tf.string)
}
Step 2: Parse the examples
We need a function to parse each serialized example:
def _parse_function(example_proto):
    # Parse the input tf.train.Example proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, feature_description)
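You can try the parser on a single raw record before mapping it over the whole dataset; a quick sketch:

for raw_record in tf.data.TFRecordDataset(filename).take(1):
    parsed = _parse_function(raw_record)
    print(parsed['feature1'])  # tf.Tensor(10.5, shape=(), dtype=float32)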
Step 3: Read the TFRecord file
Now, let's create a dataset from our TFRecord file and parse each example:
raw_dataset = tf.data.TFRecordDataset(filename)
parsed_dataset = raw_dataset.map(_parse_function)
# Print the parsed data
for parsed_record in parsed_dataset:
    print("Feature1:", parsed_record['feature1'].numpy())
    print("Feature2:", parsed_record['feature2'].numpy())
    print("Feature3:", parsed_record['feature3'].numpy().decode('utf-8'))
    print("---")
Output:
Feature1: 10.5
Feature2: [1 2 3]
Feature3: Hello
---
Feature1: -5.0
Feature2: [4 5 6]
Feature3: World
---
Feature1: 7.2
Feature2: [7 8 9]
Feature3: TFRecord
---
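A related note: if you batch the raw records before parsing, you can use the vectorized tf.io.parse_example instead of tf.io.parse_single_example, which is usually faster for large datasets. A minimal sketch reusing the same feature_description:

batched_dataset = raw_dataset.batch(2).map(
    lambda protos: tf.io.parse_example(protos, feature_description))

for batch in batched_dataset:
    print("Feature1 batch:", batch['feature1'].numpy())  # e.g. [10.5 -5. ], then [7.2]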
Handling Different Data Types in TFRecords
TFRecord files can store various types of data. Here's how to handle different data types:
Numeric Features
For numeric data, use tf.train.FloatList or tf.train.Int64List:
# For floating point values
float_feature = tf.train.Feature(float_list=tf.train.FloatList(value=[3.14, 1.618]))
# For integer values
int_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[10, 20, 30]))
String Features
For text data, use tf.train.BytesList:
text_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'hello', b'world']))
Image Data
Images can be stored as serialized bytes:
import cv2
# Read an image
image = cv2.imread('image.jpg')
# Convert image to bytes
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
# Create a bytes_feature from image bytes
image_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
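If the image is already a JPEG file on disk, you can skip OpenCV and store the file's raw bytes directly, which avoids a lossy decode/re-encode round trip. A sketch using tf.io.read_file:

raw_jpeg = tf.io.read_file('image.jpg')  # the file's contents, already JPEG-encoded
image_feature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[raw_jpeg.numpy()]))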
Real-World Application: Image Dataset
Let's look at a more practical example: storing an image dataset in TFRecord format.
Step 1: Prepare image data
import os
import glob
import cv2
import numpy as np
import tensorflow as tf
# Let's assume we have images and labels in these lists
image_paths = glob.glob('images/*.jpg')
labels = [0, 1, 0, 1, 2] # Corresponding labels
Step 2: Create a function to convert images and labels to TFRecord
def image_example(image_path, label):
    # Read image
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB

    # Encode image
    _, encoded_image = cv2.imencode('.jpg', image)

    feature = {
        'height': _int64_feature([image.shape[0]]),
        'width': _int64_feature([image.shape[1]]),
        'channels': _int64_feature([image.shape[2]]),
        'label': _int64_feature([label]),
        'image_raw': _bytes_feature(encoded_image.tobytes())
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
Step 3: Write images to TFRecord file
# Create TFRecord file
record_file = 'images.tfrecord'

with tf.io.TFRecordWriter(record_file) as writer:
    for i, image_path in enumerate(image_paths):
        if i < len(labels):  # Ensure we have a label for this image
            tf_example = image_example(image_path, labels[i])
            writer.write(tf_example.SerializeToString())

print(f"TFRecord file with images saved at: {record_file}")
Step 4: Read and display the images from TFRecord
def parse_image_function(example_proto):
    # Create a description of the features
    feature_description = {
        'height': tf.io.FixedLenFeature([1], tf.int64),
        'width': tf.io.FixedLenFeature([1], tf.int64),
        'channels': tf.io.FixedLenFeature([1], tf.int64),
        'label': tf.io.FixedLenFeature([1], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
    }

    # Parse the example
    parsed_features = tf.io.parse_single_example(example_proto, feature_description)

    # Decode the image
    image = tf.io.decode_jpeg(parsed_features['image_raw'])

    # Set the shape of the image
    height = tf.cast(parsed_features['height'][0], tf.int32)
    width = tf.cast(parsed_features['width'][0], tf.int32)
    channels = tf.cast(parsed_features['channels'][0], tf.int32)
    image = tf.reshape(image, [height, width, channels])

    # Get the label
    label = parsed_features['label'][0]

    return image, label
# Create dataset from TFRecord file
image_dataset = tf.data.TFRecordDataset(record_file)
parsed_image_dataset = image_dataset.map(parse_image_function)
import matplotlib.pyplot as plt  # needed for the display below

# Let's look at our data
count = 0
for image, label in parsed_image_dataset:
    print(f"Image shape: {image.shape}, Label: {label.numpy()}")
    count += 1

    # Display the first 2 images
    if count <= 2:
        plt.figure(figsize=(6, 6))
        plt.imshow(image.numpy())
        plt.title(f"Label: {label.numpy()}")
        plt.axis('off')
        plt.show()
TFRecord with Compression
TFRecords can be compressed to save disk space. Here's how to write and read compressed TFRecord files:
Writing compressed TFRecord
# Write compressed TFRecord
filename = 'compressed.tfrecord'
options = tf.io.TFRecordOptions(compression_type="GZIP")

with tf.io.TFRecordWriter(filename, options) as writer:
    for data_point in feature_data:
        example = create_example(data_point)
        writer.write(example.SerializeToString())

print(f"Compressed TFRecord file saved at: {filename}")
Reading compressed TFRecord
# Read compressed TFRecord
raw_dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
parsed_dataset = raw_dataset.map(_parse_function)
for parsed_record in parsed_dataset:
    print("Feature1:", parsed_record['feature1'].numpy())
Sharded TFRecord Files
For very large datasets, it's often beneficial to split the data into multiple "shards" (separate TFRecord files). This allows for:
- Faster parallel reading
- More efficient distributed training
- Better fault tolerance
Here's how to create sharded TFRecord files:
num_shards = 3
data_points_per_shard = len(feature_data) // num_shards

for i in range(num_shards):
    filename = f'data-{i:05d}-of-{num_shards:05d}.tfrecord'
    start_idx = i * data_points_per_shard
    end_idx = start_idx + data_points_per_shard if i < num_shards - 1 else len(feature_data)

    with tf.io.TFRecordWriter(filename) as writer:
        for j in range(start_idx, end_idx):
            example = create_example(feature_data[j])
            writer.write(example.SerializeToString())

    print(f"Shard {i+1}/{num_shards} saved at: {filename}")
To read all sharded files:
# Create a pattern that matches all shards
file_pattern = 'data-*-of-*.tfrecord'
dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
parsed_dataset = dataset.map(_parse_function)
print("Reading from all shards:")
for parsed_record in parsed_dataset:
print("Feature1:", parsed_record['feature1'].numpy())
Performance Tips for TFRecord
To maximize the performance benefits of TFRecord:
- Use the TFRecord format for large datasets: For small datasets, the overhead may not be worth it.
- Batch your data: Use the batch() method on your dataset:
  dataset = dataset.batch(32)
- Prefetch data: Keep the GPU fed by prefetching the next batch:
  dataset = dataset.prefetch(tf.data.AUTOTUNE)
- Cache when possible: For smaller datasets that fit in memory:
  dataset = dataset.cache()
- Parallelize data reading:
  dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
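Putting these tips together, a typical input pipeline might look like the sketch below (the ordering of map, shuffle, batch, and prefetch shown here is a common convention, not the only valid one):

filenames = tf.io.gfile.glob('data-*-of-*.tfrecord')

dataset = (tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
           .map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
           .shuffle(buffer_size=1000)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))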
Summary
In this tutorial, we've learned:
- What TFRecord files are and why they're useful for efficient data storage
- How to create and write TFRecord files with different types of data
- How to read and parse TFRecord files
- How to work with images in TFRecord format
- Advanced techniques like compression and sharding
- Performance optimization tips for working with TFRecords
TFRecord is a powerful format that can significantly improve your TensorFlow data pipeline performance, especially for large datasets. By storing your data in TFRecord format, you can make your training workflows more efficient and reduce bottlenecks related to data loading.
Additional Resources
- Official TensorFlow TFRecord Guide
- tf.data: Build TensorFlow input pipelines
- Protocol Buffers Documentation
Exercises
- Create a TFRecord file containing numeric features for a regression task (like house price prediction).
- Convert an existing classification dataset (like MNIST or CIFAR-10) to TFRecord format.
- Write a TFRecord file with mixed data types (text, images, and numeric data) and then read it back.
- Implement a data pipeline that reads from sharded TFRecord files and applies data augmentation.
- Benchmark the performance difference between reading raw image files versus reading from a TFRecord file.