TensorFlow Image Classification
Introduction
Image classification is one of the most common applications of deep learning and computer vision. It involves training a model to categorize images into predefined classes. For example, a model might be trained to recognize whether an image contains a cat, dog, car, or airplane.
In this tutorial, we'll explore how to build image classification models using TensorFlow and Convolutional Neural Networks (CNNs). CNNs are particularly well-suited for image-related tasks because they can automatically learn spatial hierarchies of features from the input images.
By the end of this tutorial, you'll be able to:
- Prepare image data for training
- Build a CNN for image classification
- Train and evaluate your model
- Make predictions with your trained model
Prerequisites
Before diving into this tutorial, you should have:
- Basic understanding of Python
- Familiarity with neural network concepts
- TensorFlow installed on your machine
If you haven't installed TensorFlow yet, you can do so with pip:
pip install tensorflow
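To confirm the installation worked (and to see whether TensorFlow can find a GPU), you can print the version and the visible devices; the exact output depends on your setup:
# Quick sanity check after installing
import tensorflow as tf
print(tf.__version__)                          # e.g. 2.x, depending on what pip installed
print(tf.config.list_physical_devices('GPU'))  # empty list if no GPU is available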
Setting Up Your Environment
Let's start by importing the necessary libraries:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
Dataset Preparation
For this tutorial, we'll use the CIFAR-10 dataset, which is a well-known benchmark dataset in computer vision. It contains 60,000 color images in 10 different classes (such as airplanes, cars, birds, cats, etc.), with 6,000 images per class.
Let's load the dataset:
# Load and split the CIFAR-10 dataset
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0
# Define class names for better interpretation
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
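Before plotting anything, it's worth confirming the shapes of the arrays we just loaded; they follow directly from how CIFAR-10 is split:
# Inspect the arrays returned by load_data()
print(train_images.shape)  # (50000, 32, 32, 3): 50,000 training images of 32x32 pixels with 3 color channels
print(test_images.shape)   # (10000, 32, 32, 3): 10,000 test images
print(train_labels.shape)  # (50000, 1): one integer label (0-9) per image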
Let's visualize some sample images from our dataset to get a better understanding of what we're working with:
# Display some sample images
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i])
    # The CIFAR labels are arrays, so we need to access the first element
    plt.xlabel(class_names[train_labels[i][0]])
plt.show()
Output: The code above will display a 5x5 grid of images from the CIFAR-10 dataset, each labeled with its respective class.
Building the CNN Model
Now, let's build a simple CNN for image classification. A typical CNN architecture for image classification includes:
- Convolutional layers for feature extraction
- Pooling layers for spatial dimension reduction
- Dense (fully connected) layers for classification
Here's how to build a basic CNN model using TensorFlow's Keras API:
# Create the CNN model
model = models.Sequential([
    # First convolutional layer
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    # Second convolutional layer
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    # Third convolutional layer
    layers.Conv2D(64, (3, 3), activation='relu'),
    # Flatten the output for the dense layers
    layers.Flatten(),
    # Dense layers for classification
    layers.Dense(64, activation='relu'),
    layers.Dense(10)  # 10 output neurons for the 10 classes
])
# Display the model architecture
model.summary()
Output:
Model: "sequential"
_________________________________________________________________
Layer (type)                   Output Shape              Param #
=================================================================
conv2d (Conv2D)                (None, 30, 30, 32)        896
_________________________________________________________________
max_pooling2d (MaxPooling2D)   (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_1 (Conv2D)              (None, 13, 13, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 6, 6, 64)          0
_________________________________________________________________
conv2d_2 (Conv2D)              (None, 4, 4, 64)          36928
_________________________________________________________________
flatten (Flatten)              (None, 1024)              0
_________________________________________________________________
dense (Dense)                  (None, 64)                65600
_________________________________________________________________
dense_1 (Dense)                (None, 10)                650
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________
Compiling and Training the Model
Next, we need to compile the model by specifying the loss function, optimizer, and metrics:
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
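Why from_logits=True? Our final Dense(10) layer outputs raw scores (logits) rather than probabilities, so we let the loss apply the softmax internally. As a sketch of the equivalent alternative (not used in the rest of this tutorial; the wrapper name probabilistic_model is just illustrative), you could append a Softmax layer and set from_logits=False:
# Equivalent alternative: put the softmax inside the model itself
probabilistic_model = models.Sequential([
    model,            # the CNN defined above, which outputs raw logits
    layers.Softmax()  # convert logits to class probabilities
])
probabilistic_model.compile(optimizer='adam',
                            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                            metrics=['accuracy'])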
Now, let's train our model:
# Train the model
history = model.fit(train_images, train_labels, epochs=10,
                    validation_data=(test_images, test_labels))
Output:
Epoch 1/10
1563/1563 [==============================] - 14s 9ms/step - loss: 1.5248 - accuracy: 0.4458 - val_loss: 1.2408 - val_accuracy: 0.5513
Epoch 2/10
1563/1563 [==============================] - 13s 8ms/step - loss: 1.1591 - accuracy: 0.5889 - val_loss: 1.0897 - val_accuracy: 0.6142
...
Epoch 10/10
1563/1563 [==============================] - 13s 9ms/step - loss: 0.6509 - accuracy: 0.7762 - val_loss: 0.8784 - val_accuracy: 0.7124
Evaluating the Model Performance
Let's visualize the training and validation accuracy over epochs:
# Plot training & validation accuracy values
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# Plot training & validation loss values
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
Let's evaluate the model's performance on the test data:
# Evaluate the model
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f'Test accuracy: {test_acc:.4f}')
Output:
313/313 - 1s - loss: 0.8784 - accuracy: 0.7124
Test accuracy: 0.7124
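Overall accuracy doesn't tell us which classes the model confuses with each other. As an optional sketch, a confusion matrix (rows are true classes, columns are predicted classes) gives a per-class breakdown:
# Optional: inspect per-class behavior with a confusion matrix
pred_logits = model.predict(test_images)
pred_labels = np.argmax(pred_logits, axis=1)
confusion = tf.math.confusion_matrix(test_labels[:, 0], pred_labels, num_classes=10)
print(confusion.numpy())  # rows: true classes, columns: predicted classes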
Making Predictions with the Trained Model
Now that we have a trained model, let's use it to make predictions on some test images:
# Get predictions for all test images
predictions = model.predict(test_images)
# Convert from logits to probabilities
predictions_proba = tf.nn.softmax(predictions).numpy()
# Function to plot image with prediction
def plot_image(i, predictions_array, true_label, img):
    true_label, img = true_label[i][0], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img, cmap=plt.cm.binary)
    predicted_label = np.argmax(predictions_array[i])
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'
    plt.xlabel(f"{class_names[predicted_label]} {100*np.max(predictions_array[i]):0.1f}% "
               f"({class_names[true_label]})", color=color)

# Let's see the model's predictions for the first 15 test images
plt.figure(figsize=(15, 6))
for i in range(15):
    plt.subplot(3, 5, i+1)
    plot_image(i, predictions_proba, test_labels, test_images)
plt.tight_layout()
plt.show()
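To classify a single image rather than a whole array, remember that the model still expects a batch dimension; here is a minimal sketch using one of the test images:
# Predict the class of a single image
img = test_images[0]                     # shape (32, 32, 3)
img_batch = np.expand_dims(img, axis=0)  # shape (1, 32, 32, 3): add a batch dimension
single_proba = tf.nn.softmax(model.predict(img_batch)).numpy()[0]
print(f"Predicted: {class_names[np.argmax(single_proba)]} "
      f"({100 * np.max(single_proba):.1f}% confidence)")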
Improving the Model
Our initial model gives decent performance, but there are several ways we can improve it:
- Data Augmentation: Create transformed versions of existing training images to increase the diversity of the training data.
# Create a data augmentation layer
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

# Create an improved model with data augmentation
# (no Rescaling layer here: the images were already scaled to [0, 1] earlier)
improved_model = models.Sequential([
    # Data augmentation layers (active only during training)
    data_augmentation,
    # Convolutional layers
    layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.MaxPooling2D((2, 2)),
    # Flatten and dense layers
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),  # Add dropout to prevent overfitting
    layers.Dense(10)
])

# Compile the improved model
improved_model.compile(optimizer='adam',
                       loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                       metrics=['accuracy'])

# Train the improved model (the images were already normalized to [0, 1] above)
improved_history = improved_model.fit(
    train_images, train_labels,
    validation_data=(test_images, test_labels),
    epochs=15
)
- Transfer Learning: Use a pre-trained model as a starting point instead of training from scratch.
# Load a pre-trained model
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3),
    include_top=False,
    weights='imagenet'
)

# Freeze the base model
base_model.trainable = False

# Create a new model on top
transfer_model = models.Sequential([
    # Resize images to the size MobileNetV2 expects
    layers.Resizing(160, 160),
    # Pre-trained base model
    base_model,
    # Add custom layers
    layers.GlobalAveragePooling2D(),
    layers.Dense(10)
])

# Compile the model
transfer_model.compile(optimizer='adam',
                       loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                       metrics=['accuracy'])

# Note: training this model still needs a small input pipeline; the Resizing layer
# handles the spatial size, but MobileNetV2 also expects inputs scaled to [-1, 1].
# One possible setup is sketched below.
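Here is one possible version of that setup, as a sketch rather than a tuned recipe (the batch size of 64 and 5 epochs are arbitrary choices, and cifar_train, cifar_test, and transfer_history are illustrative names):
# Sketch: feeding CIFAR-10 into the transfer model
# (assumes train_images/test_images are the [0, 1]-scaled arrays from earlier)
def scale_for_mobilenet(image, label):
    return image * 2.0 - 1.0, label  # [0, 1] -> [-1, 1], the range MobileNetV2 was trained on

cifar_train = (tf.data.Dataset.from_tensor_slices((train_images.astype('float32'), train_labels))
               .map(scale_for_mobilenet)
               .shuffle(1000)
               .batch(64)
               .prefetch(tf.data.AUTOTUNE))
cifar_test = (tf.data.Dataset.from_tensor_slices((test_images.astype('float32'), test_labels))
              .map(scale_for_mobilenet)
              .batch(64))

# The Resizing layer inside transfer_model upsamples the 32x32 images to 160x160
transfer_history = transfer_model.fit(cifar_train, validation_data=cifar_test, epochs=5)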
Real-World Application: Building a Custom Image Classifier
Let's look at a more practical example of using TensorFlow for image classification. Suppose we want to build a custom image classifier that can identify different types of flowers.
This example uses the tf_flowers dataset, a small collection of roughly 3,700 photos spread across five flower classes:
# Download and prepare the flowers dataset
import tensorflow_datasets as tfds

# Load the flowers dataset
dataset, info = tfds.load('tf_flowers', with_info=True, as_supervised=True)
train_dataset = dataset['train']

# Get the class names
class_names = info.features['label'].names
print(f"Class names: {class_names}")

# Function to prepare the data
def prepare_for_training(ds, shuffle_buffer_size=1000, batch_size=32):
    # Resize the images and scale them to [0, 1]
    # (MobileNetV2, used below, was pretrained on inputs in [-1, 1]; [0, 1] works here but is a simplification)
    ds = ds.map(lambda image, label: (tf.image.resize(image, [224, 224]) / 255.0, label))
    # Shuffle and batch the dataset
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    ds = ds.batch(batch_size)
    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)

# Prepare the dataset for training
train_dataset = prepare_for_training(train_dataset)
# Create a model using a pre-trained network
base_model = tf.keras.applications.MobileNetV2(
input_shape=(224, 224, 3),
include_top=False,
weights='imagenet'
)
base_model.trainable = False
# Add our custom layers
flower_model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(len(class_names))
])

# Compile the model
flower_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
# Train the model
history = flower_model.fit(
    train_dataset,
    epochs=10
)
# Now you can use this model to classify flower images
# For example:
for image_batch, label_batch in train_dataset.take(1):
    # Get predictions for a batch of images
    predictions = flower_model.predict(image_batch)
    predicted_classes = tf.argmax(predictions, axis=1)

    # Display the first 5 images with predictions
    plt.figure(figsize=(10, 10))
    for i in range(min(5, len(image_batch))):
        plt.subplot(1, 5, i + 1)
        plt.imshow(image_batch[i])
        plt.title(f"True: {class_names[label_batch[i]]}\nPred: {class_names[predicted_classes[i]]}")
        plt.axis("off")
    plt.tight_layout()
    plt.show()
Summary
In this tutorial, we've covered:
- How to prepare image data for classification tasks
- Building a basic CNN model for image classification using TensorFlow
- Training and evaluating the model
- Making predictions using the trained model
- Techniques to improve model performance (data augmentation and transfer learning)
- A real-world example of building a custom flower classifier
Image classification is a foundational task in computer vision with numerous practical applications ranging from medical diagnostics to autonomous vehicles. The techniques we've explored here can be extended to more complex tasks and datasets.
Additional Resources
- TensorFlow Image Classification Tutorial
- TensorFlow Data Augmentation Guide
- Transfer Learning and Fine-tuning
- TensorFlow Hub for Pre-trained Models
Exercises
To reinforce your learning, try these exercises:
- Exercise 1: Modify the CNN architecture we built by adding or removing layers, and observe the impact on model performance.
- Exercise 2: Try implementing data augmentation with different transformations (e.g., contrast adjustments, cutout) and evaluate whether they improve model performance.
- Exercise 3: Use a different pre-trained model (e.g., ResNet, Inception) for transfer learning and compare its performance with our MobileNetV2-based model.
- Exercise 4: Train a classifier on a different dataset, such as the Fashion MNIST dataset or a subset of ImageNet.
- Challenge: Build an end-to-end image classification system (a starter sketch follows after this list) that can:
  - Load an image from disk or URL
  - Preprocess the image appropriately
  - Make a prediction using your trained model
  - Display the result with confidence scores for different classes
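If you want a nudge on the challenge, here is a minimal starting sketch; the file name 'my_image.jpg' is a placeholder, and it reuses the CIFAR-10 model from this tutorial (the class name list is redefined here because class_names was reused for the flowers example):
# Starter sketch for the challenge: classify one image loaded from disk
cifar10_class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']
img = tf.keras.utils.load_img('my_image.jpg', target_size=(32, 32))  # placeholder path
img_array = tf.keras.utils.img_to_array(img) / 255.0                 # scale like the training data
img_array = np.expand_dims(img_array, axis=0)                        # add a batch dimension
probs = tf.nn.softmax(model.predict(img_array)).numpy()[0]
for name, p in sorted(zip(cifar10_class_names, probs), key=lambda x: -x[1])[:3]:
    print(f"{name}: {100 * p:.1f}%")  # top-3 classes with confidence scores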
Happy coding and experimenting with TensorFlow image classification!