TensorFlow TFRecord
Introduction
When working with large datasets in machine learning applications, efficient data storage and retrieval become crucial for performance. TensorFlow's TFRecord format addresses this with a simple binary file format for efficiently storing and reading a sequence of binary records.
TFRecord offers several advantages:
- Efficient storage: Data is stored in a binary format, which is more space-efficient than text formats
- Improved I/O performance: Reading binary data is faster than parsing text formats
- Atomic records: Each example is stored as a self-contained record, making it easy to shuffle and batch data
- Cross-platform compatibility: Works consistently across different operating systems and environments
In this tutorial, we'll learn how to create, write, and read TFRecord files in TensorFlow, and understand when and why you should use this format.
Understanding TFRecord Format
TFRecord is TensorFlow's binary file format designed specifically for storing a sequence of binary records. Each record in a TFRecord file contains:
- Length: The size of the data in bytes
- CRC-32C checksum of the length (for data integrity)
- Data: The actual serialized data (typically a serialized tf.train.Example protocol buffer)
- CRC-32C checksum of the data (for data integrity)
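The four fields listed above can be illustrated with a minimal pure-Python sketch. It assumes the layout described in TensorFlow's TFRecord documentation: a little-endian 64-bit length, and 32-bit checksums that are "masked" (rotated and offset by a constant) before being stored. The helper names `crc32c`, `masked_crc`, `frame_record`, and `read_records` are hypothetical and exist only for this illustration; in practice `tf.io.TFRecordWriter` and `tf.data.TFRecordDataset` handle the framing for you.

```python
import struct

def _build_table():
    # Lookup table for CRC-32C (reflected Castagnoli polynomial).
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table

_TABLE = _build_table()

def crc32c(data: bytes) -> int:
    """Plain CRC-32C of a byte string."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF

def masked_crc(data: bytes) -> int:
    """CRC-32C with the rotate-and-add masking applied before storage."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF

def frame_record(payload: bytes) -> bytes:
    """Wrap one payload as: length, CRC of length, data, CRC of data."""
    length = struct.pack("<Q", len(payload))
    return (length + struct.pack("<I", masked_crc(length))
            + payload + struct.pack("<I", masked_crc(payload)))

def read_records(buffer: bytes):
    """Yield payloads from a buffer of framed records, verifying checksums."""
    offset = 0
    while offset < len(buffer):
        length_bytes = buffer[offset:offset + 8]
        (length,) = struct.unpack("<Q", length_bytes)
        (length_crc,) = struct.unpack_from("<I", buffer, offset + 8)
        if length_crc != masked_crc(length_bytes):
            raise ValueError("corrupt length field")
        start = offset + 12
        payload = buffer[start:start + length]
        (data_crc,) = struct.unpack_from("<I", buffer, start + length)
        if data_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload
        offset = start + length + 4

# Two records back to back, then read them out again.
framed = frame_record(b"hello") + frame_record(b"world")
records = list(read_records(framed))
```

Each record adds 16 bytes of overhead (8 for the length plus two 4-byte checksums), which is why very small records are comparatively expensive to store.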
Creating TFRecord Files
Let's start by creating a simple TFRecord file with some basic data.
Step 1: Import necessary libraries
import tensorflow as tf
import numpy as np
Step 2: Create example data
Let's create some simple data that we'll store in our TFRecord file:
# Sample data: 3 examples with features
feature_data = [
    {
        'feature1': 10.5,
        'feature2': [1, 2, 3],
        'feature3': 'Hello'
    },
    {
        'feature1': -5.0,
        'feature2': [4, 5, 6],
        'feature3': 'World'
    },
    {
        'feature1': 7.2,
        'feature2': [7, 8, 9],
        'feature3': 'TFRecord'
    }
]
Step 3: Convert data to TensorFlow Examples
TFRecord files typically store data as serialized tf.train.Example protocol buffers. We need to convert our Python data to this format:
def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a list of bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    if isinstance(value, str):
        value = value.encode('utf-8')  # BytesList stores bytes, so encode text first.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def create_example(data_point):
    # Map each field to the Feature type that matches its Python type.
    feature = {
        'feature1': _float_feature(data_point['feature1']),
        'feature2': _int64_feature(data_point['feature2']),
        'feature3': _bytes_feature(data_point['feature3'])
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
Step 4: Write data to a TFRecord file
Now, let's write our examples to a TFRecord file:
filename = 'example.tfrecord'
with tf.io.TFRecordWriter(filename) as writer:
    for data_point in feature_data:
        example = create_example(data_point)
        writer.write(example.SerializeToString())
print(f"TFRecord file saved at: {filename}")
Output:
TFRecord file saved at: example.tfrecord
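Before parsing with a feature description, it can be useful to check what a file actually contains. The sketch below writes a single record to a temporary file (the path and the variable names are placeholders chosen for this illustration), reads the raw bytes back with `tf.data.TFRecordDataset`, and decodes them into a `tf.train.Example` with the protocol buffer method `ParseFromString`. No schema is needed for this kind of inspection.

```python
import os
import tempfile
import tensorflow as tf

# Build one Example with a float feature and an integer-list feature.
example = tf.train.Example(features=tf.train.Features(feature={
    "feature1": tf.train.Feature(float_list=tf.train.FloatList(value=[10.5])),
    "feature2": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 3])),
}))

# Write it to a temporary file so the sketch leaves no files behind in the
# working directory.
path = os.path.join(tempfile.mkdtemp(), "inspect.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    writer.write(example.SerializeToString())

# Each dataset element is a scalar string tensor holding one serialized record.
decoded = []
for raw_record in tf.data.TFRecordDataset(path):
    proto = tf.train.Example()
    proto.ParseFromString(raw_record.numpy())
    decoded.append(proto)

first = decoded[0]
feature1 = list(first.features.feature["feature1"].float_list.value)
feature2 = list(first.features.feature["feature2"].int64_list.value)
print(feature1, feature2)
```

Printing `first` itself shows every feature name and value in the record, which is a quick way to recover the schema of a file whose origin is unknown.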
Reading TFRecord Files
Now that we've created a TFRecord file, let's see how to read the data back.
Step 1: Create a feature description
When reading TFRecord files, we need to specify the expected format of each feature:
feature_description = {
    'feature1': tf.io.FixedLenFeature([], tf.float32),
    'feature2': tf.io.FixedLenFeature([3], tf.int64),
    'feature3': tf.io.FixedLenFeature([], tf.string)
}
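The description above fixes 'feature2' at exactly three values. If a list feature can vary in length from record to record, one option is `tf.io.VarLenFeature`, which parses the values into a sparse tensor. The following is a small sketch under that assumption; the feature name "values" and the helper `make_example` are hypothetical and not part of the tutorial's dataset.

```python
import tensorflow as tf

def make_example(values):
    """Serialize a single Example holding an integer list of any length."""
    return tf.train.Example(features=tf.train.Features(feature={
        "values": tf.train.Feature(int64_list=tf.train.Int64List(value=values)),
    })).SerializeToString()

# VarLenFeature takes only a dtype; the length is discovered per record.
description = {"values": tf.io.VarLenFeature(tf.int64)}

parsed = [
    tf.io.parse_single_example(make_example(v), description)
    for v in ([1, 2, 3], [4, 5])
]

# Each parsed value is a tf.sparse.SparseTensor; convert it to a dense tensor.
dense = [tf.sparse.to_dense(p["values"]).numpy().tolist() for p in parsed]
print(dense)
```

Records of different lengths cannot be stacked directly with `batch`, so variable-length features usually need padding (for example with `padded_batch`) later in the pipeline.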
Step 2: Parse the examples
We need a function to parse each serialized example:
def _parse_function(example_proto):
    # Parse the input tf.train.Example proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, feature_description)
Step 3: Read the TFRecord file
Now, let's create a dataset from our TFRecord file and parse each example:
raw_dataset = tf.data.TFRecordDataset(filename)
parsed_dataset = raw_dataset.map(_parse_function)
# Print the parsed data
for parsed_record in parsed_dataset:
    print("Feature1:", parsed_record['feature1'].numpy())
    print("Feature2:", parsed_record['feature2'].numpy())
    print("Feature3:", parsed_record['feature3'].numpy().decode('utf-8'))
    print("---")
Output:
Feature1: 10.5
Feature2: [1 2 3]
Feature3: Hello
---
Feature1: -5.0
Feature2: [4 5 6]
Feature3: World
---
Feature1: 7.2
Feature2: [7 8 9]
Feature3: TFRecord
---
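For training, parsed records are usually shuffled and grouped into batches rather than printed one at a time. The sketch below shows one common arrangement with `tf.data`: parse, shuffle, batch, prefetch. It writes its own six-record file to a temporary path, and the feature names "x" and "y", the buffer size, and the batch size are illustrative assumptions rather than recommended settings.

```python
import os
import tempfile
import tensorflow as tf

# Write six tiny records so the pipeline has something to read.
path = os.path.join(tempfile.mkdtemp(), "batched.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for i in range(6):
        example = tf.train.Example(features=tf.train.Features(feature={
            "x": tf.train.Feature(float_list=tf.train.FloatList(value=[float(i)])),
            "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[i % 2])),
        }))
        writer.write(example.SerializeToString())

description = {
    "x": tf.io.FixedLenFeature([], tf.float32),
    "y": tf.io.FixedLenFeature([], tf.int64),
}

def parse(serialized):
    return tf.io.parse_single_example(serialized, description)

dataset = (
    tf.data.TFRecordDataset(path)
    .map(parse, num_parallel_calls=tf.data.AUTOTUNE)  # parse records in parallel
    .shuffle(buffer_size=6, seed=0)                   # buffer covers the whole file here
    .batch(2)                                         # group records into batches of 2
    .prefetch(tf.data.AUTOTUNE)                       # overlap input work with consumption
)

batches = list(dataset)
batch_shapes = [tuple(b["x"].shape) for b in batches]
seen = sorted(float(v) for b in batches for v in b["x"].numpy())
print(batch_shapes)
```

With a real dataset the shuffle buffer is typically much smaller than the file, so shuffling is only approximate; splitting the data across several TFRecord files and shuffling the file order is a common complement.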
Handling Different Data Types in TFRecords
TFRecord files can store various types of data. Here's how to handle different data types:
Numeric Features
For numeric data, use tf.train.FloatList or tf.train.Int64List:
# For floating point values
float_feature = tf.train.Feature(float_list=tf.train.FloatList(value=[3.14, 1.618]))
# For integer values
int_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[10, 20, 30]))
String Features
For text data, use tf.train.BytesList:
text_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'hello', b'world']))
Image Data
Images can be stored as serialized bytes:
import cv2
# Read an image
image = cv2.imread('image.jpg')
# Convert image to bytes
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
# Create a bytes_feature from image bytes
image_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
Real-World Application: Image Dataset
Let's look at a more practical example: storing an image dataset in TFRecord format.
Step 1: Prepare image data
import os
import glob
import cv2
import numpy as np
import tensorflow as tf
# Let's assume we have images and labels in these lists
image_paths = sorted(glob.glob('images/*.jpg'))  # Sort so the file order is deterministic
labels = [0, 1, 0, 1, 2]  # Corresponding labels, one per image in sorted order
Step 2: Create a function to convert images and labels to TFRecord
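def image_example(image_path, label):
    # Read image (OpenCV loads it in BGR channel order)
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")
    # Encode image. cv2.imencode expects BGR input, so no BGR-to-RGB
    # conversion is applied first; doing so would swap the red and blue
    # channels in the stored JPEG.
    _, encoded_image = cv2.imencode('.jpg', image)
    feature = {
        'height': _int64_feature([image.shape[0]]),
        'width': _int64_feature([image.shape[1]]),
        'channels': _int64_feature([image.shape[2]]),
        'label': _int64_feature([label]),
        # The JPEG data is already bytes, so build the BytesList directly.
        'image_raw': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded_image.tobytes()]))
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))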
Step 3: Write images to TFRecord file
# Create TFRecord file
record_file = 'images.tfrecord'
with tf.io.TFRecordWriter(record_file) as writer:
    for i, image_path in enumerate(image_paths):
        if i < len(labels):  # Ensure we have a label for this image
            tf_example = image_example(image_path, labels[i])
            writer.write(tf_example.SerializeToString())
print(f"TFRecord file with images saved at: {record_file}")
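Image records written this way can be read back by parsing the 'image_raw' bytes and decoding them with `tf.io.decode_jpeg`. The sketch below is self-contained: instead of reading files from an images directory, it generates a small synthetic image with `tf.io.encode_jpeg` and stores it in a temporary file. The 4x6 image size, the label value, and the function name `parse_image` are placeholders chosen for this illustration.

```python
import os
import tempfile
import tensorflow as tf

# A synthetic 4x6 black RGB image stands in for a real photo.
pixels = tf.zeros([4, 6, 3], dtype=tf.uint8)
jpeg_bytes = tf.io.encode_jpeg(pixels).numpy()

example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
}))

path = os.path.join(tempfile.mkdtemp(), "images_sketch.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    writer.write(example.SerializeToString())

image_description = {
    "label": tf.io.FixedLenFeature([], tf.int64),
    "image_raw": tf.io.FixedLenFeature([], tf.string),
}

def parse_image(serialized):
    parsed = tf.io.parse_single_example(serialized, image_description)
    # decode_jpeg returns a uint8 tensor of shape [height, width, channels] in RGB order.
    image = tf.io.decode_jpeg(parsed["image_raw"], channels=3)
    return image, parsed["label"]

decoded = [
    (image.numpy(), int(label))
    for image, label in tf.data.TFRecordDataset(path).map(parse_image)
]
print(decoded[0][0].shape, decoded[0][1])
```

Because the JPEG bytes carry their own dimensions, the stored 'height', 'width', and 'channels' features are not required for decoding; they are mainly useful when images are stored as raw pixel buffers or when the shape is needed without decoding.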