TensorFlow Knowledge Distillation

Introduction

Knowledge distillation is a model compression technique that enables transferring knowledge from a large, complex model (the "teacher") to a smaller, more efficient model (the "student"). Introduced by Geoffrey Hinton and his team in 2015, this approach allows us to deploy lightweight models that maintain much of the performance of larger models - making them suitable for mobile devices, edge computing, and scenarios with limited computational resources.

In this tutorial, we'll explore how to implement knowledge distillation using TensorFlow, understand the underlying concepts, and see how it can be applied to real-world problems.

Why Knowledge Distillation Matters

Before diving into the implementation, let's understand why knowledge distillation is important:

  1. Model Efficiency: Smaller models require fewer computational resources and less memory
  2. Deployment Flexibility: Compressed models can run on resource-constrained devices
  3. Inference Speed: Smaller models typically provide faster inference times
  4. Energy Efficiency: Less computation translates to lower energy consumption

The Theory Behind Knowledge Distillation

Knowledge distillation works by training a smaller student model to mimic the behavior of a larger pre-trained teacher model. The key insight is that we don't rely solely on the hard class labels for training; we also leverage the "soft targets" (probability distributions) produced by the teacher model, which encode how the teacher relates the classes to one another.

The basic knowledge distillation loss combines two components:

  1. Distillation Loss: How well the student matches the teacher's soft probabilities
  2. Student Loss: How well the student performs on the original task

Let's break down the key concepts:

  • Temperature (τ): A hyperparameter that "softens" probability distributions, revealing more information about the teacher's learned relationships (illustrated in the short sketch after this list)
  • Soft targets: Probability distributions from the teacher model, softened by temperature
  • Hard targets: The original ground truth labels
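
To make the effect of temperature concrete, here is a minimal sketch (with purely illustrative logits, not tied to the models built later) showing how dividing by a temperature before the softmax spreads probability mass across classes:

python
import tensorflow as tf

# Hypothetical logits for a 5-class problem (illustrative values only)
logits = tf.constant([[8.0, 2.0, 1.0, 0.5, 0.1]])

hard = tf.nn.softmax(logits)        # temperature = 1: close to one-hot
soft = tf.nn.softmax(logits / 5.0)  # temperature = 5: much softer distribution

print(hard.numpy().round(3))  # nearly all mass on the first class
print(soft.numpy().round(3))  # mass spread out; class similarities become visible

The softer distribution is what the teacher passes to the student as a "soft target" during distillation.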

Implementing Knowledge Distillation in TensorFlow

Let's implement knowledge distillation step by step using TensorFlow:

Step 1: Import Required Libraries

python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Input, Dropout
from tensorflow.keras.models import Model

Step 2: Prepare the Dataset

We'll use the MNIST dataset for this example:

python
# Load and preprocess the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize the data
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Reshape for CNN input
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# One-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

print(f"Training data shape: {x_train.shape}")
print(f"Testing data shape: {x_test.shape}")

Output:

Training data shape: (60000, 28, 28, 1)
Testing data shape: (10000, 28, 28, 1)

Step 3: Create Teacher Model (Complex Model)

First, let's create our larger teacher model:

python
def create_teacher_model():
    inputs = Input(shape=(28, 28, 1))
    x = Conv2D(32, kernel_size=(3, 3), activation='relu')(inputs)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(10, activation='softmax', name='predictions')(x)

    model = Model(inputs=inputs, outputs=outputs)
    return model

teacher_model = create_teacher_model()
teacher_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

teacher_model.summary()

Output:

Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 26, 26, 32) 320
_________________________________________________________________
...
_________________________________________________________________
predictions (Dense) (None, 10) 1290
=================================================================
Total params: 1,199,882
Trainable params: 1,199,882
Non-trainable params: 0
_________________________________________________________________

Step 4: Train the Teacher Model

Let's train our teacher model:

python
teacher_model.fit(
    x_train, y_train,
    epochs=10,
    batch_size=128,
    validation_data=(x_test, y_test),
    verbose=1
)

# Evaluate the teacher model
teacher_score = teacher_model.evaluate(x_test, y_test, verbose=0)
print(f"Teacher model accuracy: {teacher_score[1]:.4f}")

Output:

Epoch 1/10
469/469 [==============================] - 24s 51ms/step - loss: 0.3615 - accuracy: 0.8868 - val_loss: 0.0796 - val_accuracy: 0.9756
...
Epoch 10/10
469/469 [==============================] - 24s 50ms/step - loss: 0.0346 - accuracy: 0.9893 - val_loss: 0.0326 - val_accuracy: 0.9901
Teacher model accuracy: 0.9901

Step 5: Create Student Model (Simplified Model)

Now, let's create a smaller, simpler student model:

python
def create_student_model():
    inputs = Input(shape=(28, 28, 1))
    x = Flatten()(inputs)
    x = Dense(32, activation='relu')(x)
    outputs = Dense(10, activation='softmax', name='predictions')(x)

    model = Model(inputs=inputs, outputs=outputs)
    return model

student_model = create_student_model()
student_model.summary()

Output:

Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
flatten_1 (Flatten) (None, 784) 0
_________________________________________________________________
dense_1 (Dense) (None, 32) 25120
_________________________________________________________________
predictions (Dense) (None, 10) 330
=================================================================
Total params: 25,450
Trainable params: 25,450
Non-trainable params: 0
_________________________________________________________________

As you can see, the student model has far fewer parameters than the teacher model!

Step 6: Implement the Knowledge Distillation Loss

Now, let's implement our knowledge distillation mechanism:

python
class DistillationModel(Model):
    def __init__(self, student, teacher, temp=5.0, alpha=0.5):
        super(DistillationModel, self).__init__()
        self.student = student
        self.teacher = teacher
        self.temp = temp
        self.alpha = alpha

    def compile(self, optimizer, metrics=None):
        super(DistillationModel, self).compile(optimizer=optimizer, metrics=metrics)
        self.distillation_loss_fn = tf.keras.losses.KLDivergence()
        self.student_loss_fn = tf.keras.losses.CategoricalCrossentropy()

    def train_step(self, data):
        # Unpack the data
        x, y = data

        # Forward pass of the (frozen) teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of the student
            student_predictions = self.student(x, training=True)

            # Compute student loss (with hard targets)
            student_loss = self.student_loss_fn(y, student_predictions)

            # Apply softmax with temperature for distillation.
            # Note: both models above end in a softmax, so this re-softens their
            # probability outputs; with logit outputs this would be the standard
            # Hinton-style temperature scaling.
            temp_teacher = tf.nn.softmax(teacher_predictions / self.temp)
            temp_student = tf.nn.softmax(student_predictions / self.temp)

            # Compute distillation loss (soft targets), scaled by T^2
            distillation_loss = self.distillation_loss_fn(temp_teacher, temp_student) * (self.temp ** 2)

            # Combine losses: α*student_loss + (1-α)*distillation_loss
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients for the student only
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss, "distillation_loss": distillation_loss})
        return results

    def test_step(self, data):
        # Needed so fit(..., validation_data=...) and evaluate() work:
        # validate the student against the hard targets only.
        x, y = data
        student_predictions = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, student_predictions)
        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

Step 7: Train the Student Model with Knowledge Distillation

Now, we'll train the student model using our knowledge distillation approach:

python
# Freeze the teacher model
teacher_model.trainable = False

# Create and compile the distillation model
distiller = DistillationModel(
    student=student_model,
    teacher=teacher_model,
    temp=5.0,   # Temperature parameter
    alpha=0.5   # Balance between student loss and distillation loss
)

distiller.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    metrics=['accuracy']
)

# Train the student model with distillation
history = distiller.fit(
    x_train, y_train,
    epochs=15,
    batch_size=128,
    validation_data=(x_test, y_test)
)

Output:

Epoch 1/15
469/469 [==============================] - 3s 6ms/step - accuracy: 0.8619 - student_loss: 0.4203 - distillation_loss: 0.1943 - val_accuracy: 0.9600
...
Epoch 15/15
469/469 [==============================] - 3s 6ms/step - accuracy: 0.9766 - student_loss: 0.0662 - distillation_loss: 0.0198 - val_accuracy: 0.9752

Step 8: Compare Results

Let's evaluate both models on the test set and compare their performance:

python
# Train a student model without distillation for comparison
student_model_vanilla = create_student_model()
student_model_vanilla.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

student_model_vanilla.fit(
    x_train, y_train,
    epochs=15,
    batch_size=128,
    validation_data=(x_test, y_test),
    verbose=0
)

# Compile the distilled student so it can be evaluated directly
# (its weights were already trained inside the distiller)
student_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Evaluate all models
teacher_score = teacher_model.evaluate(x_test, y_test, verbose=0)
student_distill_score = student_model.evaluate(x_test, y_test, verbose=0)
student_vanilla_score = student_model_vanilla.evaluate(x_test, y_test, verbose=0)

print(f"Teacher model accuracy: {teacher_score[1]:.4f}")
print(f"Student with distillation accuracy: {student_distill_score[1]:.4f}")
print(f"Student without distillation accuracy: {student_vanilla_score[1]:.4f}")

# Model size comparison
teacher_params = teacher_model.count_params()
student_params = student_model.count_params()
print(f"Teacher model parameters: {teacher_params:,}")
print(f"Student model parameters: {student_params:,}")
print(f"Model size reduction: {(1 - student_params/teacher_params) * 100:.2f}%")

Output:

Teacher model accuracy: 0.9901
Student with distillation accuracy: 0.9752
Student without distillation accuracy: 0.9710
Teacher model parameters: 1,199,882
Student model parameters: 25,450
Model size reduction: 97.88%

Step 9: Visualize the Results

Let's create some visualizations to better understand the knowledge distillation process:

python
def plot_model_comparison():
    labels = ['Teacher', 'Student with Distillation', 'Student without Distillation']
    accuracies = [teacher_score[1], student_distill_score[1], student_vanilla_score[1]]
    params = [teacher_params/1000000, student_params/1000000, student_params/1000000]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Accuracy comparison
    ax1.bar(labels, accuracies, color=['blue', 'green', 'red'])
    ax1.set_title('Model Accuracy Comparison')
    ax1.set_ylim([0.95, 1.0])  # Focus on the relevant range
    ax1.set_ylabel('Accuracy')
    for i, v in enumerate(accuracies):
        ax1.text(i, v + 0.005, f"{v:.4f}", ha='center')

    # Model size comparison
    ax2.bar(labels, params, color=['blue', 'green', 'red'])
    ax2.set_title('Model Size Comparison (Million Parameters)')
    ax2.set_ylabel('Parameters (Millions)')
    for i, v in enumerate(params):
        ax2.text(i, v + 0.05, f"{v:.2f}M", ha='center')

    plt.tight_layout()
    plt.show()

plot_model_comparison()

This visualization would show the accuracy and size comparison between the models.

Real-world Applications of Knowledge Distillation

Knowledge distillation has numerous practical applications:

1. Mobile and Edge Device Deployment

A common use case for knowledge distillation is deploying models to mobile applications:

python
# Example: Converting and optimizing for TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_keras_model(student_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TFLite model
with open('mnist_student_model.tflite', 'wb') as f:
    f.write(tflite_model)

# Compare on-disk sizes: save the full teacher and check its file size
import os
teacher_model.save('mnist_teacher_model.h5')
print(f"Original teacher model size: {os.path.getsize('mnist_teacher_model.h5') / 1024:.2f} KB")
print(f"Distilled TFLite model size: {len(tflite_model) / 1024:.2f} KB")

2. Ensemble Knowledge Distillation

You can distill knowledge from multiple teacher models into a single student:

python
def ensemble_distillation(x, teacher_models, student_model, temp=5.0):
    # Get predictions from all teachers
    teacher_preds = [teacher(x, training=False) for teacher in teacher_models]

    # Average the teacher predictions
    ensemble_pred = tf.reduce_mean(teacher_preds, axis=0)

    # Student prediction
    student_pred = student_model(x, training=True)

    # Apply temperature scaling
    soft_targets = tf.nn.softmax(ensemble_pred / temp)
    soft_prob = tf.nn.softmax(student_pred / temp)

    # Calculate the distillation loss (scaled by T^2)
    distill_loss = tf.keras.losses.kullback_leibler_divergence(soft_targets, soft_prob) * (temp ** 2)

    return distill_loss
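
A minimal sketch of how this loss might be used inside a custom training step (assuming two hypothetical pre-trained teachers, teacher_a and teacher_b, and combining it with the hard-label loss as in Step 6):

python
# Hypothetical training step that weights the ensemble distillation loss
# against the usual hard-label loss (teacher_a / teacher_b are assumed
# pre-trained models with the same output shape as the student).
optimizer = tf.keras.optimizers.Adam(1e-3)
hard_loss_fn = tf.keras.losses.CategoricalCrossentropy()

@tf.function
def ensemble_train_step(x, y, student, alpha=0.5):
    with tf.GradientTape() as tape:
        student_pred = student(x, training=True)
        hard_loss = hard_loss_fn(y, student_pred)
        # ensemble_distillation runs its own student forward pass inside the tape
        soft_loss = tf.reduce_mean(
            ensemble_distillation(x, [teacher_a, teacher_b], student, temp=5.0))
        loss = alpha * hard_loss + (1 - alpha) * soft_loss
    grads = tape.gradient(loss, student.trainable_variables)
    optimizer.apply_gradients(zip(grads, student.trainable_variables))
    return loss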

3. Cross-Modal Distillation

Knowledge distillation can transfer knowledge between different types of neural networks or even different modalities:

python
# Hypothetical example: aligning a text student's embeddings with a vision teacher's embeddings
def cross_modal_distill(image_input, text_input, vision_teacher, text_student):
    # Get embeddings from the vision teacher
    vision_embedding = vision_teacher(image_input, training=False)

    # Get embeddings from the text student
    text_embedding = text_student(text_input, training=True)

    # Align the (L2-normalized) embeddings through distillation
    alignment_loss = tf.reduce_mean(
        tf.square(tf.nn.l2_normalize(vision_embedding, axis=1) -
                  tf.nn.l2_normalize(text_embedding, axis=1))
    )

    return alignment_loss

Best Practices for Knowledge Distillation

  1. Temperature Selection: Higher temperatures make the probability distribution softer, exposing more knowledge about relationships between classes
  2. Alpha Balance: Find the right balance between hard-label loss and distillation loss
  3. Student Architecture Design: Choose an appropriate student architecture with sufficient capacity to learn from the teacher
  4. Pre-training: Consider pre-training the student on hard labels before distillation
  5. Feature Distillation: For deeper networks, consider distilling intermediate layer activations as well (see the sketch after this list)
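
Here is a rough sketch of what a feature distillation term could look like for the teacher and student built above (the layer names 'dense' and 'dense_1' are assumptions based on Keras auto-naming; check model.summary() for the actual names):

python
# Minimal feature distillation sketch: build feature extractors once,
# then penalize the distance between teacher and (projected) student features.
teacher_feat_model = Model(teacher_model.input,
                           teacher_model.get_layer('dense').output)    # 128-dim features
student_feat_model = Model(student_model.input,
                           student_model.get_layer('dense_1').output)  # 32-dim features

# Small trainable projection so student features match the teacher's width
projector = Dense(128, use_bias=False)

def feature_distillation_loss(x):
    teacher_feat = teacher_feat_model(x, training=False)
    student_feat = student_feat_model(x, training=True)
    # Mean-squared error between teacher features and projected student features
    return tf.reduce_mean(tf.square(teacher_feat - projector(student_feat)))

This term would be added to the combined loss from Step 6 with its own weighting coefficient, and the projector's weights would be trained along with the student.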

Summary

In this tutorial, we've covered:

  • The concept of knowledge distillation and why it's useful
  • How to implement knowledge distillation in TensorFlow
  • Creating teacher and student models with different architectures
  • Training a student model to mimic a teacher model's outputs
  • Comparing performance between vanilla training and distillation
  • Real-world applications of knowledge distillation

Knowledge distillation allows us to create smaller, faster models that maintain most of the performance of larger models. This technique is particularly valuable for deploying deep learning models in resource-constrained environments.

Exercises

  1. Try different temperature values (1, 3, 5, 10) and observe how they affect the distillation performance
  2. Experiment with different student architectures to find the smallest model that maintains acceptable performance
  3. Implement "attention distillation" by transferring intermediate layer activations from teacher to student
  4. Apply knowledge distillation to a different dataset like CIFAR-10
  5. Explore quantization-aware distillation to further compress your model for edge devices

