TensorFlow Code Organization
Introduction
When you're just starting with TensorFlow, it's easy to focus solely on making your models work. However, as your projects grow in complexity, good code organization becomes crucial. Well-organized TensorFlow code is easier to debug, share with others, and reuse across projects. This guide will walk you through best practices for structuring your TensorFlow code, from small scripts to large-scale machine learning projects.
Why Code Organization Matters
Before diving into specific patterns, let's understand why code organization is particularly important for TensorFlow projects:
- Reproducibility: Well-organized code makes it easier to reproduce results
- Collaboration: Team members can understand and contribute to the codebase
- Maintenance: Debugging and updating models becomes simpler
- Scalability: A sound structure keeps working as projects grow in complexity
- Deployment: Clean separation makes production deployment smoother
Basic Project Structure
For any TensorFlow project beyond simple experiments, consider this basic directory structure:
my_tensorflow_project/
├── data/           # Data files and preprocessing scripts
├── models/         # Model definitions
├── config/         # Configuration files
├── utils/          # Utility functions
├── train.py        # Training script
├── evaluate.py     # Evaluation script
├── predict.py      # Inference script
└── README.md       # Project documentation
This structure separates concerns and makes your project easier to navigate. Let's explore how to implement this effectively.
Modularizing Model Code
The Model Class Approach
Rather than defining models in the same file as your training code, create dedicated model classes:
# models/simple_cnn.py
import tensorflow as tf

class SimpleCNN(tf.keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=False):
        x = self.conv1(inputs)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)
Then in your training code:
# train.py
import tensorflow as tf
from models.simple_cnn import SimpleCNN

model = SimpleCNN(num_classes=10)
# Configure training and fit the model...
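Completing that last step, here is a minimal sketch of configuring and fitting the model (train_dataset stands in for whatever tf.data pipeline your data module provides):

# train.py (continued)
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # SimpleCNN outputs logits
    metrics=['accuracy']
)
model.fit(train_dataset, epochs=10)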
Using Keras Functional API
For complex models, the Keras Functional API allows for cleaner designs:
# models/complex_model.py
import tensorflow as tf

def create_complex_model(input_shape, num_classes):
    inputs = tf.keras.Input(shape=input_shape)

    # Feature extraction path
    x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Conv2D(64, 3, activation='relu')(x)
    x = tf.keras.layers.MaxPooling2D()(x)

    # Classification head
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)
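A quick way to sanity-check the architecture is to build it and print a summary, here assuming MNIST-sized inputs:

# e.g., in a scratch script or notebook
from models.complex_model import create_complex_model

model = create_complex_model(input_shape=(28, 28, 1), num_classes=10)
model.summary()  # prints layer output shapes and parameter counts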
Configuration Management
Instead of hardcoding parameters, use configuration files or objects:
# config/model_config.py
class ModelConfig:
    def __init__(self):
        # Model architecture
        self.input_shape = (28, 28, 1)
        self.num_classes = 10

        # Training parameters
        self.batch_size = 32
        self.epochs = 10
        self.learning_rate = 0.001

        # Paths
        self.data_dir = "./data"
        self.checkpoint_dir = "./checkpoints"

config = ModelConfig()
Using this in your code:
# train.py
import tensorflow as tf

from config.model_config import config
from models.complex_model import create_complex_model

# Create model using config parameters
model = create_complex_model(config.input_shape, config.num_classes)

# Configure optimizer with config
optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)
# And so on...
Data Pipeline Organization
Separating data handling code makes your projects more maintainable:
# data/mnist_data.py
import tensorflow as tf

def load_and_prepare_data(batch_size):
    # Load MNIST dataset
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Normalize and reshape
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0
    x_train = x_train.reshape(-1, 28, 28, 1)
    x_test = x_test.reshape(-1, 28, 28, 1)

    # Create tf.data.Dataset
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = train_dataset.shuffle(10000).batch(batch_size)
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_dataset = test_dataset.batch(batch_size)

    return train_dataset, test_dataset
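A quick smoke test of the pipeline, pulling a single batch to confirm shapes (a batch size of 32 is assumed here):

# e.g., in a scratch script or notebook
from data.mnist_data import load_and_prepare_data

train_dataset, test_dataset = load_and_prepare_data(batch_size=32)
for images, labels in train_dataset.take(1):
    print(images.shape)  # (32, 28, 28, 1)
    print(labels.shape)  # (32,)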
Custom Training Loops
For advanced use cases, organize custom training loops into functions:
# train.py
import tensorflow as tf

from models.simple_cnn import SimpleCNN
from data.mnist_data import load_and_prepare_data
from config.model_config import config

# Create the loss object once rather than on every step
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def train_step(model, optimizer, x_batch, y_batch):
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, logits)
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_value

def train_model():
    # Setup
    model = SimpleCNN(num_classes=config.num_classes)
    optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)
    train_dataset, test_dataset = load_and_prepare_data(config.batch_size)

    # Training loop
    for epoch in range(config.epochs):
        print(f"Epoch {epoch+1}/{config.epochs}")
        epoch_loss_avg = tf.keras.metrics.Mean()
        for x_batch, y_batch in train_dataset:
            loss = train_step(model, optimizer, x_batch, y_batch)
            epoch_loss_avg.update_state(loss)
        print(f"Loss: {epoch_loss_avg.result().numpy()}")

    # Save the trained model
    model.save(f"{config.checkpoint_dir}/final_model")

if __name__ == "__main__":
    train_model()
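As a performance note, decorating the step function with tf.function compiles it into a TensorFlow graph, which typically speeds up the loop considerably; the body stays the same:

# Same train_step as above, compiled into a graph
@tf.function
def train_step(model, optimizer, x_batch, y_batch):
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, logits)
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_value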
Callbacks and Logging
Organize monitoring and logging with callbacks:
# utils/callbacks.py
import tensorflow as tf
import os

def get_callbacks(config):
    # Create checkpoint directory if it doesn't exist
    os.makedirs(config.checkpoint_dir, exist_ok=True)

    callbacks = [
        # Save the best checkpoint seen so far during training
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(config.checkpoint_dir, "model_{epoch}"),
            save_best_only=True,
            monitor='val_loss'
        ),
        # Early stopping to prevent overfitting
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=3,
            restore_best_weights=True
        ),
        # TensorBoard logging
        tf.keras.callbacks.TensorBoard(
            log_dir='./logs',
            histogram_freq=1
        )
    ]
    return callbacks
Using these callbacks:
# In train.py
from utils.callbacks import get_callbacks

# For models trained with model.fit()
model.fit(
    train_dataset,
    epochs=config.epochs,
    validation_data=test_dataset,
    callbacks=get_callbacks(config)
)
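With the TensorBoard callback writing to ./logs, you can monitor training live by running tensorboard --logdir ./logs from the project root and opening the URL it prints (http://localhost:6006 by default).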
Real-World Example: Image Classification Project
Let's put all these concepts together with a complete example of an image classification project:
# Project structure
# image_classifier/
# ├── config/
# │   └── config.py
# ├── data/
# │   └── data_loader.py
# ├── models/
# │   └── resnet_classifier.py
# ├── utils/
# │   ├── callbacks.py
# │   └── visualization.py
# ├── train.py
# └── evaluate.py
# config/config.py
class Config:
    def __init__(self):
        self.input_shape = (224, 224, 3)
        self.num_classes = 5
        self.batch_size = 32
        self.epochs = 20
        self.learning_rate = 0.0001
        self.data_dir = "./data/flower_photos"
        self.checkpoint_dir = "./checkpoints"
# data/data_loader.py
import tensorflow as tf
import pathlib

def load_dataset(config):
    data_dir = pathlib.Path(config.data_dir)

    # Create dataset from directory structure
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=123,
        image_size=(224, 224),
        batch_size=config.batch_size
    )
    val_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=123,
        image_size=(224, 224),
        batch_size=config.batch_size
    )

    # Performance optimization
    AUTOTUNE = tf.data.AUTOTUNE
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

    return train_ds, val_ds
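One subtlety worth noting: image_dataset_from_directory attaches a class_names attribute to the dataset it returns, but transformations such as cache() and prefetch() produce new Dataset objects that do not carry it. If you need the label names, capture them before the optimization step; a small sketch of the adjustment:

# Inside load_dataset, before the performance optimizations (illustrative)
class_names = train_ds.class_names  # e.g., ['daisy', 'dandelion', ...]
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
# ...and return class_names alongside the datasets if callers need it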
# models/resnet_classifier.py
import tensorflow as tf

def create_model(config):
    # Use pre-trained ResNet50 as base
    base_model = tf.keras.applications.ResNet50(
        weights='imagenet',
        include_top=False,
        input_shape=config.input_shape
    )

    # Freeze the base model
    base_model.trainable = False

    # Add custom classifier head
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(config.num_classes)
    ])
    return model
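One caveat with this model: ResNet50's ImageNet weights expect inputs normalized by tf.keras.applications.resnet50.preprocess_input, and image_dataset_from_directory does not apply that. Here is a sketch of a functional-API variant that bakes the preprocessing into the model itself (an adjustment to the code above, not part of the original project):

# models/resnet_classifier.py (variant with built-in preprocessing)
import tensorflow as tf

def create_model_with_preprocessing(config):
    inputs = tf.keras.Input(shape=config.input_shape)
    # Normalize raw pixels the way ResNet50 was trained on ImageNet
    x = tf.keras.applications.resnet50.preprocess_input(inputs)
    base_model = tf.keras.applications.ResNet50(
        weights='imagenet', include_top=False, input_shape=config.input_shape)
    base_model.trainable = False        # freeze pretrained weights
    x = base_model(x, training=False)   # keep BatchNorm layers in inference mode
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(256, activation='relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    outputs = tf.keras.layers.Dense(config.num_classes)(x)  # logits
    return tf.keras.Model(inputs, outputs)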
# train.py
import tensorflow as tf

from config.config import Config
from data.data_loader import load_dataset
from models.resnet_classifier import create_model
from utils.callbacks import get_callbacks

def train():
    # Initialize configuration
    config = Config()

    # Load data
    train_ds, val_ds = load_dataset(config)

    # Create model
    model = create_model(config)

    # Compile model
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=config.learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )

    # Train model
    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=config.epochs,
        callbacks=get_callbacks(config)
    )

    # Save final model
    model.save('flower_classifier')

if __name__ == "__main__":
    train()
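The project structure also lists evaluate.py, which the example above does not show. Here is a minimal sketch of what it could contain, assuming the model saved by train.py:

# evaluate.py
import tensorflow as tf

from config.config import Config
from data.data_loader import load_dataset

def evaluate():
    config = Config()
    _, val_ds = load_dataset(config)

    # Load the model saved at the end of training
    model = tf.keras.models.load_model('flower_classifier')

    loss, accuracy = model.evaluate(val_ds)
    print(f"Validation loss: {loss:.4f}, accuracy: {accuracy:.4f}")

if __name__ == "__main__":
    evaluate()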
Containerizing Your TensorFlow Project
For production-ready code organization, consider using Docker:
# Dockerfile
FROM tensorflow/tensorflow:2.11.0-gpu
WORKDIR /app
# Copy requirements and install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy project files
COPY . .
# Default command
CMD ["python", "train.py"]
Summary and Best Practices
Let's summarize the key points for organizing TensorFlow code:
- Separate concerns with a clear directory structure
- Modularize model definitions in dedicated files
- Use configuration objects or files instead of hardcoding parameters
- Organize data pipelines in their own modules
- Create utility functions for common operations
- Use callbacks for monitoring and checkpointing
- Document your code thoroughly
By following these practices, you'll create TensorFlow projects that are easier to maintain, collaborate on, and eventually deploy to production.
Exercises
- Take an existing TensorFlow script and reorganize it according to the structure outlined in this guide.
- Create a configuration file for a project and modify your code to use the config values instead of hardcoded parameters.
- Implement a custom training loop that includes logging, checkpointing, and early stopping.
- Refactor a model defined within a training script into its own module with clear documentation.