TensorFlow Image Classification

Introduction

Image classification is one of the most common applications of deep learning and computer vision. It involves training a model to categorize images into predefined classes. For example, a model might be trained to recognize whether an image contains a cat, dog, car, or airplane.

In this tutorial, we'll explore how to build image classification models using TensorFlow and Convolutional Neural Networks (CNNs). CNNs are particularly well-suited for image-related tasks because they can automatically learn spatial hierarchies of features from the input images.

By the end of this tutorial, you'll be able to:

  • Prepare image data for training
  • Build a CNN for image classification
  • Train and evaluate your model
  • Make predictions with your trained model

Prerequisites

Before diving into this tutorial, you should have:

  • Basic understanding of Python
  • Familiarity with neural networks concepts
  • TensorFlow installed on your machine

If you haven't installed TensorFlow yet, you can do so with pip:

bash
pip install tensorflow
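
You can check that the installation works by printing the TensorFlow version:

python
import tensorflow as tf
print(tf.__version__)  # should print a 2.x version string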

Setting Up Your Environment

Let's start by importing the necessary libraries:

python
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np

Dataset Preparation

For this tutorial, we'll use the CIFAR-10 dataset, a well-known benchmark in computer vision. It contains 60,000 32×32 color images in 10 classes (airplanes, cars, birds, cats, etc.), with 6,000 images per class, split into 50,000 training images and 10,000 test images.

Let's load the dataset:

python
# Load and split the CIFAR-10 dataset
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0

# Define class names for better interpretation
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
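
Before going further, it's worth sanity-checking the array shapes. Keras returns 50,000 training and 10,000 test images, each 32×32 with 3 color channels:

python
print(train_images.shape)  # (50000, 32, 32, 3)
print(train_labels.shape)  # (50000, 1) - each label is a 1-element array
print(test_images.shape)   # (10000, 32, 32, 3)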

Let's visualize some sample images from our dataset to get a better understanding of what we're working with:

python
# Display some sample images
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i])
    # The CIFAR labels are arrays, so we need to access the first element
    plt.xlabel(class_names[train_labels[i][0]])
plt.show()

Output: The code above will display a 5x5 grid of images from the CIFAR-10 dataset, each labeled with its respective class.

Building the CNN Model

Now, let's build a simple CNN for image classification. A typical CNN architecture for image classification includes:

  1. Convolutional layers for feature extraction
  2. Pooling layers for spatial dimension reduction
  3. Dense (fully connected) layers for classification

Here's how to build a basic CNN model using TensorFlow's Keras API:

python
# Create the CNN model
model = models.Sequential([
    # First convolutional layer
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),

    # Second convolutional layer
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    # Third convolutional layer
    layers.Conv2D(64, (3, 3), activation='relu'),

    # Flatten the output for the dense layers
    layers.Flatten(),

    # Dense layers for classification
    layers.Dense(64, activation='relu'),
    layers.Dense(10)  # 10 output neurons for the 10 classes; raw logits, no softmax
])

# Display the model architecture
model.summary()

Output:

Model: "sequential"
_________________________________________________________________
Layer (type)                   Output Shape              Param #
=================================================================
conv2d (Conv2D)                (None, 30, 30, 32)        896
_________________________________________________________________
max_pooling2d (MaxPooling2D)   (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_1 (Conv2D)              (None, 13, 13, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 6, 6, 64)          0
_________________________________________________________________
conv2d_2 (Conv2D)              (None, 4, 4, 64)          36928
_________________________________________________________________
flatten (Flatten)              (None, 1024)              0
_________________________________________________________________
dense (Dense)                  (None, 64)                65600
_________________________________________________________________
dense_1 (Dense)                (None, 10)                650
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________
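
As a sanity check on these numbers: a Conv2D layer has (kernel height × kernel width × input channels × filters) weights plus one bias per filter, so the first layer contributes 3×3×3×32 + 32 = 896 parameters, the second 3×3×32×64 + 64 = 18,496, and the first Dense layer 1024×64 + 64 = 65,600.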

Compiling and Training the Model

Next, we need to compile the model by specifying the loss function, optimizer, and metrics:

python
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
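
We pass from_logits=True because the model's final Dense layer outputs raw scores (logits) rather than probabilities; the loss function then applies the softmax internally. A quick illustration:

python
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
logits = tf.constant([[2.0, 1.0, 0.1]])  # raw scores for one sample
label = tf.constant([0])                 # the true class index
print(loss_fn(label, logits).numpy())    # ~0.4170, i.e. -log(softmax(logits)[0])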

Now, let's train our model:

python
# Train the model
history = model.fit(train_images, train_labels, epochs=10,
                    validation_data=(test_images, test_labels))

Output:

Epoch 1/10
1563/1563 [==============================] - 14s 9ms/step - loss: 1.5248 - accuracy: 0.4458 - val_loss: 1.2408 - val_accuracy: 0.5513
Epoch 2/10
1563/1563 [==============================] - 13s 8ms/step - loss: 1.1591 - accuracy: 0.5889 - val_loss: 1.0897 - val_accuracy: 0.6142
...
Epoch 10/10
1563/1563 [==============================] - 13s 9ms/step - loss: 0.6509 - accuracy: 0.7762 - val_loss: 0.8784 - val_accuracy: 0.7124

Evaluating the Model Performance

Let's visualize the training and validation accuracy over epochs:

python
# Plot training & validation accuracy values
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Plot training & validation loss values
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

Let's evaluate the model's performance on the test data:

python
# Evaluate the model
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f'Test accuracy: {test_acc:.4f}')

Output:

313/313 - 1s - loss: 0.8784 - accuracy: 0.7124
Test accuracy: 0.7124
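
The gap between the final training accuracy (~0.78) and the test accuracy (~0.71) suggests the model is starting to overfit; we'll address that in the Improving the Model section.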

Making Predictions with the Trained Model

Now that we have a trained model, let's use it to make predictions on some test images:

python
# Get predictions for all test images
predictions = model.predict(test_images)

# Convert from logits to probabilities
predictions_proba = tf.nn.softmax(predictions).numpy()

# Function to plot image with prediction
def plot_image(i, predictions_array, true_label, img):
    true_label, img = true_label[i][0], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img, cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array[i])
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'

    plt.xlabel(f"{class_names[predicted_label]} {100*np.max(predictions_array[i]):0.1f}% "
               f"({class_names[true_label]})", color=color)

# Let's see the model's predictions for the first 15 test images
plt.figure(figsize=(15, 6))
for i in range(15):
    plt.subplot(3, 5, i+1)
    plot_image(i, predictions_proba, test_labels, test_images)
plt.tight_layout()
plt.show()
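
For a closer look at where the model goes wrong, a confusion matrix summarizes all 10,000 test predictions in a 10×10 table, with rows for true classes and columns for predicted classes:

python
# Summarize the test predictions in a confusion matrix
predicted_labels = np.argmax(predictions_proba, axis=1)
confusion = tf.math.confusion_matrix(test_labels.flatten(), predicted_labels)
print(confusion.numpy())  # rows: true class, columns: predicted class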

Improving the Model

Our initial model gives decent performance, but there are several ways we can improve it:

  1. Data Augmentation: Create transformed versions of existing training images to increase the diversity of the training data.
python
# Create a data augmentation pipeline
# (in older TF versions these layers lived under layers.experimental.preprocessing)
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

# Create an improved model with data augmentation
improved_model = models.Sequential([
    # Fix the input shape so the model can be built immediately
    tf.keras.Input(shape=(32, 32, 3)),

    # Data augmentation layers (active only during training)
    data_augmentation,

    # Normalization layer
    layers.Rescaling(1./255),

    # Convolutional layers
    layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.MaxPooling2D((2, 2)),

    # Flatten and dense layers
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),  # Add dropout to prevent overfitting
    layers.Dense(10)
])

# Compile the improved model
improved_model.compile(optimizer='adam',
                       loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                       metrics=['accuracy'])

# The Rescaling layer normalizes inputs inside the model, so train on the raw
# 0-255 images (the arrays we divided by 255 earlier would be scaled twice)
(raw_train_images, raw_train_labels), (raw_test_images, raw_test_labels) = datasets.cifar10.load_data()

improved_history = improved_model.fit(
    raw_train_images, raw_train_labels,
    validation_data=(raw_test_images, raw_test_labels),
    epochs=15
)
  2. Transfer Learning: Use a pre-trained model as a starting point instead of training from scratch.
python
# Load a pre-trained model
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3),
    include_top=False,
    weights='imagenet'
)

# Freeze the base model so its ImageNet weights are not updated
base_model.trainable = False

# Create a new model on top
transfer_model = models.Sequential([
    # Resize images to the input size expected by MobileNetV2
    layers.Resizing(160, 160),

    # Pre-trained base model
    base_model,

    # Add custom layers
    layers.GlobalAveragePooling2D(),
    layers.Dense(10)
])

# Compile the model
transfer_model.compile(optimizer='adam',
                       loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                       metrics=['accuracy'])

# Train the model (the CIFAR-10 images would need rescaling to MobileNetV2's
# expected input range first; see the sketch below)
# Note: this is just an example and would need more setup for actual training
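
To actually train this on CIFAR-10, the main missing piece is input scaling: MobileNetV2 expects pixel values in [-1, 1]. Here is a minimal sketch, assuming the normalized [0, 1] arrays from earlier are reused (the Resizing layer in the model handles the upscaling from 32×32):

python
# Minimal sketch (assumption: reuses the normalized CIFAR-10 arrays from above)
# Map pixel values from [0, 1] to the [-1, 1] range MobileNetV2 expects
scaled_train = train_images * 2.0 - 1.0
scaled_test = test_images * 2.0 - 1.0

# The Resizing layer upscales the 32x32 images, so the arrays can be passed directly
# transfer_history = transfer_model.fit(scaled_train, train_labels,
#                                       validation_data=(scaled_test, test_labels),
#                                       epochs=5)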

Real-World Application: Building a Custom Image Classifier

Let's look at a more practical example of using TensorFlow for image classification. Suppose we want to build a custom image classifier that can identify different types of flowers.

This example uses a smaller subset of the Flowers dataset:

python
# Download and prepare the flowers dataset
import tensorflow_datasets as tfds

# Load the flowers dataset
dataset, info = tfds.load('tf_flowers', with_info=True, as_supervised=True)
train_dataset = dataset['train']

# Get the class names
class_names = info.features['label'].names
print(f"Class names: {class_names}")

# Function to prepare the data
def prepare_for_training(ds, shuffle_buffer_size=1000, batch_size=32):
    # Resize the images and scale pixel values to [0, 1]
    # (MobileNetV2 was trained on inputs in [-1, 1]; for best results you could
    # use tf.keras.applications.mobilenet_v2.preprocess_input instead)
    ds = ds.map(lambda image, label: (tf.image.resize(image, [224, 224]) / 255.0, label))

    # Shuffle and batch the dataset
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
    ds = ds.batch(batch_size)

    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)

# Prepare the dataset for training
train_dataset = prepare_for_training(train_dataset)

# Create a model using a pre-trained network
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet'
)

base_model.trainable = False

# Add our custom layers
flower_model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(len(class_names))
])

# Compile the model
flower_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

# Train the model
flower_history = flower_model.fit(
    train_dataset,
    epochs=10
)

# Now you can use this model to classify flower images
# For example:
for image_batch, label_batch in train_dataset.take(1):
    # Get predictions for a batch of images
    predictions = flower_model.predict(image_batch)
    predicted_classes = tf.argmax(predictions, axis=1)

    # Display the first 5 images with predictions
    plt.figure(figsize=(10, 10))
    for i in range(min(5, len(image_batch))):
        plt.subplot(1, 5, i + 1)
        plt.imshow(image_batch[i])
        plt.title(f"True: {class_names[label_batch[i]]}\nPred: {class_names[predicted_classes[i]]}")
        plt.axis("off")
    plt.tight_layout()
    plt.show()
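
If the results look good, the model can be saved and reloaded later rather than retrained; for example (the file name here is arbitrary):

python
# Save the trained classifier to disk and load it back later
flower_model.save('flower_model.keras')
reloaded_model = tf.keras.models.load_model('flower_model.keras')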

Summary

In this tutorial, we've covered:

  1. How to prepare image data for classification tasks
  2. Building a basic CNN model for image classification using TensorFlow
  3. Training and evaluating the model
  4. Making predictions using the trained model
  5. Techniques to improve model performance (data augmentation and transfer learning)
  6. A real-world example of building a custom flower classifier

Image classification is a foundational task in computer vision with numerous practical applications ranging from medical diagnostics to autonomous vehicles. The techniques we've explored here can be extended to more complex tasks and datasets.

Exercises

To reinforce your learning, try these exercises:

  1. Exercise 1: Modify the CNN architecture we built by adding or removing layers, and observe the impact on model performance.

  2. Exercise 2: Try implementing data augmentation with different transformations (e.g., contrast adjustments, cutout) and evaluate if they improve model performance.

  3. Exercise 3: Use a different pre-trained model (e.g., ResNet, Inception) for transfer learning and compare its performance with our MobileNetV2-based model.

  4. Exercise 4: Train a classifier on a different dataset, such as the Fashion MNIST dataset or a subset of ImageNet.

  5. Challenge: Build an end-to-end image classification system (a starting-point sketch follows this list) that can:

    • Load an image from disk or URL
    • Preprocess the image appropriately
    • Make a prediction using your trained model
    • Display the result with confidence scores for different classes
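
If you want a starting point for the challenge, here is a sketch. It assumes a model saved as 'my_model.keras' and the class_names list from this tutorial (both hypothetical placeholders); adjust the image size and preprocessing to match whatever model you actually use:

python
import numpy as np
import tensorflow as tf

def classify_image(path, model, class_names, image_size=(32, 32)):
    # Load the image from disk and resize it to the model's input size
    img = tf.keras.utils.load_img(path, target_size=image_size)
    arr = tf.keras.utils.img_to_array(img) / 255.0  # same scaling as training
    arr = np.expand_dims(arr, axis=0)               # add a batch dimension
    probs = tf.nn.softmax(model.predict(arr)[0]).numpy()
    # Print the top-3 classes with confidence scores
    for idx in np.argsort(probs)[::-1][:3]:
        print(f"{class_names[idx]}: {probs[idx]:.1%}")

# Example usage (hypothetical paths):
# model = tf.keras.models.load_model('my_model.keras')
# classify_image('some_image.jpg', model, class_names)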

Happy coding and experimenting with TensorFlow image classification!


