
TensorFlow Style Guide

Introduction

Writing clean, consistent, and maintainable code is essential for any software project, and TensorFlow applications are no exception. The TensorFlow Style Guide provides a set of conventions and best practices that help developers write code that is readable, efficient, and aligned with the broader TensorFlow ecosystem.

This guide will help you understand the recommended style conventions for TensorFlow code, making your projects more accessible to other developers and ensuring that your code follows established best practices in the TensorFlow community.

Why Follow a Style Guide?

Before diving into the specifics, let's understand why a style guide is important:

  1. Consistency: Makes code more readable across projects and teams
  2. Maintainability: Makes code easier to debug and update
  3. Collaboration: Helps team members understand each other's code
  4. Quality: Promotes best practices that lead to fewer bugs

TensorFlow Naming Conventions

Variable Names

When working with TensorFlow, use descriptive variable names that indicate the purpose and type of the data:

python
# Good examples
learning_rate = 0.01
hidden_layer_sizes = [128, 64, 32]
training_examples = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# Avoid short, cryptic names
lr = 0.01 # Too short
h_l_s = [128, 64, 32] # Unclear abbreviation
x = tf.constant([[1.0, 2.0], [3.0, 4.0]]) # Not descriptive enough

Tensor Names

For tensors and operations, it's useful to include information about tensor dimensions or data types:

python
# Good examples
image_batch = tf.zeros([32, 128, 128, 3]) # [batch_size, height, width, channels]
embedding_vectors = tf.zeros([1000, 256]) # [vocabulary_size, embedding_dim]
logits_tensor = tf.zeros([32, 10]) # [batch_size, num_classes]

Model and Layer Names

When defining models and layers, use descriptive class names:

python
class ImageClassifier(tf.keras.Model):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(10)

    def call(self, inputs):
        # Forward pass through the layers defined above
        x = self.conv1(inputs)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)

Code Organization

Import Statements

Organize your imports in the following order:

python
# 1. Standard library imports
import os
import sys
from datetime import datetime

# 2. Third-party imports
import numpy as np
import pandas as pd

# 3. TensorFlow imports
import tensorflow as tf
from tensorflow.keras import layers, models

# 4. Local application imports
from myproject.data import preprocessing
from myproject.models import architecture

Function and Class Structure

Structure your functions and classes logically:

python
def preprocess_data(data, normalize=True):
    """Preprocess the input data.

    Args:
        data: Input data to be processed
        normalize: Whether to normalize the data

    Returns:
        Processed data
    """
    # Implementation here
    return processed_data


class MyModel(tf.keras.Model):
    """A custom model for some specific task.

    This model implements...
    """

    def __init__(self, param1, param2):
        """Initialize the model.

        Args:
            param1: Description of param1
            param2: Description of param2
        """
        super(MyModel, self).__init__()
        # Initialize layers

    def call(self, inputs, training=None):
        """Forward pass of the model.

        Args:
            inputs: Input tensor
            training: Whether in training mode

        Returns:
            Output tensor
        """
        # Implementation here
        return outputs

TensorFlow-Specific Best Practices

Use TensorFlow 2.x Style

TensorFlow 2.x promotes eager execution and an object-oriented approach. Prefer this over the TensorFlow 1.x graph-and-session style:

python
# TensorFlow 2.x style (recommended)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1)
])

# Rather than TensorFlow 1.x style
x = tf.placeholder(tf.float32, shape=[None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
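
Because eager execution is the TF 2.x default, operations run immediately and return concrete values you can inspect, with no sessions or placeholders required. A minimal illustration (the tensor values here are arbitrary):

python
import tensorflow as tf

# Eager execution: ops run immediately and return concrete values
a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
b = tf.constant([[1.0, 1.0], [0.0, 1.0]])
product = tf.matmul(a, b)
print(product.numpy())  # [[1. 3.] [3. 7.]]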

Use tf.data for Input Pipelines

Prefer tf.data for creating efficient input pipelines:

python
# Create an efficient input pipeline
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(32)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

# Can be chained for readability
dataset = tf.data.Dataset.from_tensor_slices((features, labels)) \
    .shuffle(buffer_size=1000) \
    .batch(32) \
    .prefetch(tf.data.AUTOTUNE)

Keras API Conventions

When using Keras, follow these conventions:

python
# Define a model using the Functional API
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(128, activation='relu')(inputs)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Or use the Sequential API for linear stacks of layers
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

# Compile with clear names
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

Documentation

Docstrings

Add clear docstrings to functions, classes, and methods:

python
def create_cnn_model(input_shape, num_classes):
    """Creates a CNN model for image classification.

    Args:
        input_shape: Shape of input images, e.g., (height, width, channels)
        num_classes: Number of output classes

    Returns:
        A compiled Keras model

    Example:
        model = create_cnn_model((28, 28, 1), 10)
        history = model.fit(train_images, train_labels, epochs=5)
    """
    # Implementation here
    return model

Comments

Add comments to explain complex parts of your code:

python
# Apply a custom learning rate schedule
# We use cosine decay with warmup for better convergence
initial_learning_rate = 0.1
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate,
    decay_steps=1000,
    alpha=0.01
)

# Warmup phase - linearly increase lr for first 100 steps
def warmup_cosine_decay_schedule(step):
    warmup_steps = 100
    if step < warmup_steps:
        return initial_learning_rate * (step / warmup_steps)
    else:
        return lr_schedule(step - warmup_steps)
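
One caveat: a plain Python function like warmup_cosine_decay_schedule cannot be passed directly as an optimizer's learning_rate. A minimal sketch of one way to make it usable, assuming TF 2.x Keras, is to wrap the same logic in a tf.keras.optimizers.schedules.LearningRateSchedule subclass (the class name WarmupCosineDecay is illustrative):

python
import tensorflow as tf

class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup followed by cosine decay."""

    def __init__(self, initial_learning_rate, warmup_steps, decay_steps):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        # Cosine decay applied after the warmup phase
        self.cosine_decay = tf.keras.optimizers.schedules.CosineDecay(
            initial_learning_rate, decay_steps=decay_steps, alpha=0.01
        )

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_lr = self.initial_learning_rate * (step / self.warmup_steps)
        # Linear warmup for the first warmup_steps, cosine decay afterwards
        return tf.cond(step < self.warmup_steps,
                       lambda: warmup_lr,
                       lambda: self.cosine_decay(step - self.warmup_steps))

optimizer = tf.keras.optimizers.Adam(
    learning_rate=WarmupCosineDecay(0.1, warmup_steps=100, decay_steps=1000)
)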

Real-World Example: Image Classification Model

Let's apply these style guidelines to a complete image classification example:

python
import tensorflow as tf
from tensorflow.keras import layers, models

def load_and_preprocess_data():
    """Load and preprocess the CIFAR-10 dataset.

    Returns:
        Tuple of training data, test data, and class names
    """
    # Load the CIFAR-10 dataset
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

    # Normalize pixel values to be between 0 and 1
    train_images = train_images.astype('float32') / 255.0
    test_images = test_images.astype('float32') / 255.0

    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    return (train_images, train_labels), (test_images, test_labels), class_names

def create_data_pipeline(images, labels, batch_size=64, is_training=True):
    """Create an optimized data pipeline.

    Args:
        images: Image data
        labels: Label data
        batch_size: Size of batches
        is_training: Whether this is for training (includes shuffling)

    Returns:
        A tf.data.Dataset
    """
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    if is_training:
        # Shuffle the data
        dataset = dataset.shuffle(buffer_size=10000)

        # Add data augmentation for training
        def augment(image, label):
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_brightness(image, max_delta=0.1)
            return image, label

        dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch and prefetch for performance
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

class CNNModel(models.Model):
    """Convolutional Neural Network for image classification."""

    def __init__(self, num_classes=10):
        """Initialize the CNN model.

        Args:
            num_classes: Number of output classes
        """
        super(CNNModel, self).__init__()

        # Feature extraction layers
        self.conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')
        self.batch_norm1 = layers.BatchNormalization()
        self.pool1 = layers.MaxPooling2D((2, 2))

        self.conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')
        self.batch_norm2 = layers.BatchNormalization()
        self.pool2 = layers.MaxPooling2D((2, 2))

        self.conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')
        self.batch_norm3 = layers.BatchNormalization()
        self.pool3 = layers.MaxPooling2D((2, 2))

        # Classification layers
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(256, activation='relu')
        self.dropout = layers.Dropout(0.5)
        self.dense2 = layers.Dense(num_classes)

    def call(self, inputs, training=False):
        """Forward pass of the model.

        Args:
            inputs: Input images
            training: Whether in training mode

        Returns:
            Output logits
        """
        # First block
        x = self.conv1(inputs)
        x = self.batch_norm1(x, training=training)
        x = self.pool1(x)

        # Second block
        x = self.conv2(x)
        x = self.batch_norm2(x, training=training)
        x = self.pool2(x)

        # Third block
        x = self.conv3(x)
        x = self.batch_norm3(x, training=training)
        x = self.pool3(x)

        # Classification head
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout(x, training=training)  # Dropout is a no-op when training=False
        x = self.dense2(x)

        return x

def train_model(model, train_dataset, validation_dataset, epochs=10):
    """Train the model.

    Args:
        model: The model to train
        train_dataset: Training dataset
        validation_dataset: Validation dataset
        epochs: Number of epochs to train

    Returns:
        Training history
    """
    # Define callbacks
    # Save weights only: subclassed models can't be serialized to a full HDF5 file
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        'best_model.weights.h5',
        save_weights_only=True,
        save_best_only=True,
        monitor='val_accuracy'
    )

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=3,
        restore_best_weights=True
    )

    # Train the model
    history = model.fit(
        train_dataset,
        epochs=epochs,
        validation_data=validation_dataset,
        callbacks=[checkpoint_callback, early_stopping]
    )

    return history

def main():
    # Load and preprocess data
    (train_images, train_labels), (test_images, test_labels), class_names = load_and_preprocess_data()

    # Create data pipelines
    train_dataset = create_data_pipeline(train_images, train_labels, is_training=True)
    test_dataset = create_data_pipeline(test_images, test_labels, is_training=False)

    # Create the model
    model = CNNModel(num_classes=10)

    # Compile the model
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )

    # Print model summary
    model.build((None, 32, 32, 3))
    model.summary()

    # Train the model
    history = train_model(model, train_dataset, test_dataset, epochs=15)

    # Evaluate on test data
    test_loss, test_acc = model.evaluate(test_dataset)
    print(f'Test accuracy: {test_acc:.4f}')

if __name__ == '__main__':
    main()

This example demonstrates:

  • Clear module structure
  • Proper docstrings
  • Descriptive variable names
  • Logical function organization
  • Proper use of TensorFlow 2.x and Keras APIs
  • Efficient data pipeline with tf.data
  • Custom model with clear layer organization

Common Pitfalls to Avoid

  1. Mixing TensorFlow 1.x and 2.x styles: Be consistent with the TensorFlow 2.x eager execution style.

  2. Inconsistent variable naming: Don't mix naming conventions (e.g., snake_case and camelCase).

  3. Magic numbers: Avoid hardcoding values without explanation.

    python
    # Avoid this
    x = input_tensor * 0.017453292519943295

    # Better approach
    DEGREES_TO_RADIANS = 0.017453292519943295 # π/180
    x = input_tensor * DEGREES_TO_RADIANS
  4. Overly complex functions: Break down complex logic into smaller, reusable functions (see the sketch after this list).

  5. Missing documentation: Always include docstrings for functions and classes.
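
To illustrate pitfall 4, here is a hypothetical refactoring sketch (the names build_model, compile_model, and run_experiment are illustrative): each responsibility lives in its own small function, and a top-level function composes them.

python
import tensorflow as tf

def build_model(num_classes):
    """Builds and returns an uncompiled classification model."""
    return tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(num_classes)
    ])

def compile_model(model, learning_rate=0.001):
    """Compiles the model with a standard optimizer/loss/metric setup."""
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    return model

def run_experiment(train_dataset, num_classes=10, epochs=5):
    """Composes the small functions into a readable training run."""
    model = compile_model(build_model(num_classes))
    return model.fit(train_dataset, epochs=epochs)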

Summary

Following a consistent style guide is essential for writing clean, maintainable TensorFlow code. In this guide, we covered:

  • Naming conventions for variables, tensors, models, and layers
  • Code organization best practices
  • TensorFlow 2.x style recommendations
  • Documentation standards
  • A complete real-world example demonstrating these principles

By adhering to these guidelines, you'll write TensorFlow code that is not only functional but also readable, maintainable, and aligned with community best practices.

Exercises

  1. Take a small TensorFlow script you've written and refactor it according to this style guide.

  2. Review an open-source TensorFlow project on GitHub and identify how it follows (or doesn't follow) these style conventions.

  3. Create a small image classification model following these style guidelines. Make sure to include proper documentation, naming conventions, and code organization.

  4. Refactor the following code snippet to follow the TensorFlow style guide:

    python
    def fn(x,y):
        a=tf.matmul(x,y)
        b=tf.reduce_sum(a,0)
        c=tf.nn.relu(b)
        return c
  5. Write a custom data augmentation pipeline using tf.data that follows the style guidelines presented in this guide.


