TensorFlow Siamese Networks
Introduction
Siamese Networks are a special type of neural network architecture designed to learn similarity between inputs. Unlike conventional neural networks that classify or predict values, Siamese networks determine how similar (or different) two inputs are. Named after "Siamese twins," these networks share weights between two or more identical subnetworks, processing different inputs to compare them in a learned feature space.
In this tutorial, you'll learn:
- What Siamese Networks are and how they work
- When to use Siamese architectures over traditional networks
- How to implement a Siamese Network using TensorFlow
- Real-world applications of Siamese Networks
Understanding Siamese Networks
What Are Siamese Networks?
Siamese Networks consist of two (or more) identical neural networks with shared weights. Each network processes a different input, and the outputs are compared to determine similarity. This architecture is particularly useful for applications where we need to compare two examples rather than simply classify an individual example.
Key Components of Siamese Networks:
- Twin Networks: Two identical neural networks that share the exact same weights
- Feature Extraction: Each network extracts features from its input
- Distance Function: A mechanism to measure similarity between the extracted features
- Loss Function: Usually contrastive loss or triplet loss to train the network
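To make the weight-sharing idea concrete, here is a minimal NumPy sketch rather than a real network. The sizes, the single dense layer, and the names `embed` and `distance` are assumptions chosen for illustration only. It shows that both inputs pass through the same weights, so identical inputs land on the same embedding and the comparison does not depend on input order.

```python
import numpy as np

rng = np.random.default_rng(0)

# One set of weights, reused for both inputs (hypothetical toy sizes)
W = rng.normal(size=(4, 3))   # maps 4-dim inputs to 3-dim embeddings
b = np.zeros(3)

def embed(x):
    """Toy 'twin' network: a single dense layer with ReLU."""
    return np.maximum(x @ W + b, 0.0)

def distance(x1, x2):
    """Euclidean distance between the two embeddings."""
    return float(np.linalg.norm(embed(x1) - embed(x2)))

a = rng.normal(size=4)
c = rng.normal(size=4)

print(distance(a, a))                     # identical inputs give 0.0
print(distance(a, c) == distance(c, a))   # swapping the inputs changes nothing
```

In a trained Siamese network the same property holds, with the dense layer replaced by a deeper feature extractor and the weights learned through the loss function.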
Why Use Siamese Networks?
- Few-shot learning: Learn from very few examples per class
- Face verification: Determine if two faces belong to the same person
- Signature verification: Check if two signatures match
- Image similarity: Find similar images in a database
- Document matching: Compare document similarities
Building a Siamese Network in TensorFlow
Let's build a Siamese Network for comparing handwritten digits using the MNIST dataset. Our goal will be to determine whether two digit images represent the same number.
Step 1: Import Required Libraries
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, Lambda
from tensorflow.keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
import random
Step 2: Prepare the Dataset
We need to structure our data differently for Siamese networks. Instead of individual examples and labels, we need pairs of examples and a binary label indicating whether they belong to the same class.
# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize and reshape images
train_images = train_images / 255.0
test_images = test_images / 255.0
# Reshape for CNN
train_images = train_images.reshape(-1, 28, 28, 1)
test_images = test_images.reshape(-1, 28, 28, 1)
# Function to create pairs
def create_pairs(images, labels):
    pairs = []
    labels_pairs = []
    # For each digit
    for digit_index in range(10):
        # Get all indices for this digit
        digit_indices = np.where(labels == digit_index)[0]
        # Make positive pairs (same digit)
        for i in range(min(300, len(digit_indices) - 1)):
            idx1, idx2 = digit_indices[i], digit_indices[i + 1]
            pairs.append([images[idx1], images[idx2]])
            labels_pairs.append(1)  # 1 means same class
        # Make negative pairs (different digits)
        other_digits = list(range(10))
        other_digits.remove(digit_index)
        # Get indices for a randomly selected different digit
        different_digit = random.choice(other_digits)
        different_digit_indices = np.where(labels == different_digit)[0]
        for i in range(min(300, len(digit_indices))):
            idx1 = digit_indices[i]
            idx2 = different_digit_indices[i % len(different_digit_indices)]
            pairs.append([images[idx1], images[idx2]])
            labels_pairs.append(0)  # 0 means different class
    return np.array(pairs), np.array(labels_pairs)
# Create training and testing pairs
train_pairs, train_pair_labels = create_pairs(train_images, train_labels)
test_pairs, test_pair_labels = create_pairs(test_images, test_labels)
# Print shapes
print("Training pairs shape:", train_pairs.shape)
print("Training labels shape:", train_pair_labels.shape)
print("Testing pairs shape:", test_pairs.shape)
print("Testing labels shape:", test_pair_labels.shape)
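The pairing function above draws all negatives for a digit from a single randomly chosen other digit. One possible variant, sketched below under the assumption that more varied negatives are desirable, samples each negative partner from any other class and returns index pairs instead of image copies. The name `make_balanced_pairs` and its parameters are hypothetical and not part of the tutorial code.

```python
import numpy as np

def make_balanced_pairs(labels, pairs_per_class, seed=0):
    """Return (index_pairs, pair_labels) with one negative per positive."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    index_pairs, pair_labels = [], []
    for cls in np.unique(labels):
        same = np.flatnonzero(labels == cls)
        other = np.flatnonzero(labels != cls)
        if len(same) < 2 or len(other) == 0:
            continue  # cannot form both pair types for this class
        for _ in range(pairs_per_class):
            # Positive pair: two distinct samples of the same class
            i, j = rng.choice(same, size=2, replace=False)
            index_pairs.append((i, j))
            pair_labels.append(1)
            # Negative pair: partner drawn from any other class
            k = rng.choice(other)
            index_pairs.append((i, k))
            pair_labels.append(0)
    return np.array(index_pairs), np.array(pair_labels)

toy_labels = [0, 0, 0, 1, 1, 2, 2, 2]
idx, y = make_balanced_pairs(toy_labels, pairs_per_class=4)
print(idx.shape, y.mean())  # 24 pairs, half of them positive
```

With real data, `images[idx[:, 0]]` and `images[idx[:, 1]]` would give the two input arrays, which avoids building a Python list of image pairs first.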
Step 3: Build the Base Network
Let's define the base network that will be shared between the twin networks:
def build_base_network(input_shape):
    input_layer = Input(shape=input_shape)
    # Convolutional layers
    x = Conv2D(32, (3, 3), activation='relu')(input_layer)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)
    # Dense layers for embedding
    x = Dense(128, activation='relu')(x)
    embedding = Dense(64, activation='relu')(x)
    # Create the model
    model = Model(inputs=input_layer, outputs=embedding)
    return model
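As a sanity check on the architecture, the feature-map sizes and parameter count can be worked out by hand. The sketch below assumes the Keras defaults used above (no padding and stride 1 for `Conv2D`, non-overlapping windows for `MaxPooling2D`); the helper names `conv_out` and `pool_out` are made up for this example.

```python
def conv_out(size, kernel):
    """Output size of a 'valid' (no padding), stride-1 convolution."""
    return size - kernel + 1

def pool_out(size, pool):
    """Output size of non-overlapping max pooling (remainder dropped)."""
    return size // pool

size = 28
size = conv_out(size, 3)   # 26
size = pool_out(size, 2)   # 13
size = conv_out(size, 3)   # 11
size = pool_out(size, 2)   # 5
flat = size * size * 64    # 1600 features after Flatten

params = (
    (3 * 3 * 1 * 32 + 32)      # Conv2D(32): weights + biases
    + (3 * 3 * 32 * 64 + 64)   # Conv2D(64)
    + (flat * 128 + 128)       # Dense(128)
    + (128 * 64 + 64)          # Dense(64)
)
print(flat, params)  # 1600 232000
```

Because the twin branches share one base network, the full Siamese model should report this same number of trainable parameters, not twice as many.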
Step 4: Create the Siamese Network
Now, let's create the complete Siamese network architecture:
def euclidean_distance(vectors):
    """Compute the Euclidean distance between two batches of vectors"""
    x, y = vectors
    sum_square = tf.math.reduce_sum(tf.math.square(x - y), axis=1, keepdims=True)
    # Clamp before the square root: the gradient of sqrt is undefined at 0,
    # which can produce NaNs when two embeddings are identical
    return tf.math.sqrt(tf.math.maximum(sum_square, 1e-7))
def cosine_similarity(vectors):
    """Compute the cosine similarity between two batches of vectors"""
    x, y = vectors
    x_norm = tf.math.l2_normalize(x, axis=1)
    y_norm = tf.math.l2_normalize(y, axis=1)
    return tf.math.reduce_sum(x_norm * y_norm, axis=1, keepdims=True)
def build_siamese_model(input_shape):
    # Define the base network
    base_network = build_base_network(input_shape)
    # Define the two inputs
    input_a = Input(shape=input_shape)
    input_b = Input(shape=input_shape)
    # Pass inputs through the base network
    processed_a = base_network(input_a)
    processed_b = base_network(input_b)
    # Compute the distance between the outputs
    distance = Lambda(euclidean_distance)([processed_a, processed_b])
    # Create the model
    model = Model(inputs=[input_a, input_b], outputs=distance)
    return model
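# Build the model
input_shape = (28, 28, 1)
siamese_model = build_siamese_model(input_shape)
# Model summary
siamese_model.summary()
Step 5: Define the Loss Function and Compile the Model
For Siamese networks, we typically use a contrastive loss function. Because compile() needs the loss, we define it first and compile the model afterwards. The built-in 'accuracy' metric does not fit a model that outputs distances, so we also define a small metric that treats a distance below 0.5 as "same":
def contrastive_loss(y_true, y_pred):
    """
    Contrastive loss function.
    Parameters:
        y_true: Label (1 for similar pairs, 0 for dissimilar pairs)
        y_pred: Euclidean distance between the pair embeddings
    Returns:
        Loss value
    """
    margin = 1.0
    # Match the labels' dtype and shape to the (batch, 1) distance output
    y_true = tf.cast(tf.reshape(y_true, (-1, 1)), y_pred.dtype)
    # For similar pairs (y_true=1), we want the distance to be small
    # For dissimilar pairs (y_true=0), we want the distance to be larger than the margin
    similarity_part = y_true * tf.square(y_pred)
    dissimilarity_part = (1 - y_true) * tf.square(tf.maximum(margin - y_pred, 0))
    return tf.reduce_mean(similarity_part + dissimilarity_part) / 2
def accuracy(y_true, y_pred):
    """Fraction of pairs classified correctly when distance < 0.5 means 'same'."""
    y_true = tf.cast(tf.reshape(y_true, (-1, 1)), y_pred.dtype)
    predicted_same = tf.cast(y_pred < 0.5, y_pred.dtype)
    return tf.reduce_mean(tf.cast(tf.equal(y_true, predicted_same), y_pred.dtype))
# Compile the model
siamese_model.compile(
    loss=contrastive_loss,
    optimizer=tf.keras.optimizers.Adam(0.0001),
    metrics=[accuracy]
)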
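A few hand-checkable numbers may help build intuition for this loss. The sketch below is a NumPy re-statement of the same formula that returns one value per pair instead of the batch mean; the function name `contrastive_loss_np` and the example distances are made up for illustration.

```python
import numpy as np

def contrastive_loss_np(y_true, distance, margin=1.0):
    """NumPy version of the contrastive loss, returned per pair (no averaging)."""
    y_true = np.asarray(y_true, dtype=float)
    distance = np.asarray(distance, dtype=float)
    similar = y_true * distance ** 2
    dissimilar = (1.0 - y_true) * np.maximum(margin - distance, 0.0) ** 2
    return (similar + dissimilar) / 2.0

labels    = [1,   1,   0,   0]     # 1 = same class, 0 = different class
distances = [0.1, 0.9, 0.2, 1.5]   # hypothetical model outputs
per_pair = contrastive_loss_np(labels, distances)
print(per_pair)         # roughly 0.005, 0.405, 0.32 and 0.0
print(per_pair.mean())  # corresponds to the scalar the TensorFlow loss returns
```

The similar pair that is far apart (0.9) and the dissimilar pair that is close together (0.2) dominate the loss, while the dissimilar pair already beyond the margin (1.5) contributes nothing.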
Step 6: Train the Model
Now let's train our Siamese network:
# Restructure data for training
input_a_train = train_pairs[:, 0] # First image in each pair
input_b_train = train_pairs[:, 1] # Second image in each pair
input_a_test = test_pairs[:, 0]
input_b_test = test_pairs[:, 1]
# Train the model
history = siamese_model.fit(
    [input_a_train, input_b_train],
    train_pair_labels,
    validation_data=([input_a_test, input_b_test], test_pair_labels),
    batch_size=64,
    epochs=10
)
# Plot training history
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.tight_layout()
plt.show()
Step 7: Evaluate the Model and Make Predictions
Let's test our model on some example pairs:
# Function to visualize pairs and predictions
def visualize_predictions(pairs, true_labels, model, num_examples=5):
    # Make predictions
    predictions = model.predict([pairs[:, 0], pairs[:, 1]])
    # Convert distances to similarity scores
    similarity_scores = 1 - predictions  # Higher score means more similar
    # Set a threshold for prediction
    threshold = 0.5
    predicted_labels = (similarity_scores > threshold).astype(int).flatten()
    # Plot examples
    plt.figure(figsize=(15, 4 * num_examples))
    for i in range(num_examples):
        # Choose a random index
        idx = np.random.randint(0, len(pairs))
        plt.subplot(num_examples, 3, i*3 + 1)
        plt.imshow(pairs[idx, 0].reshape(28, 28), cmap='gray')
        plt.title("Image 1")
        plt.axis('off')
        plt.subplot(num_examples, 3, i*3 + 2)
        plt.imshow(pairs[idx, 1].reshape(28, 28), cmap='gray')
        plt.title("Image 2")
        plt.axis('off')
        plt.subplot(num_examples, 3, i*3 + 3)
        plt.text(0.5, 0.5, f"True label: {'Same' if true_labels[idx] else 'Different'}\n" +
                 f"Predicted: {'Same' if predicted_labels[idx] else 'Different'}\n" +
                 f"Similarity Score: {similarity_scores[idx][0]:.4f}",
                 horizontalalignment='center', verticalalignment='center', fontsize=14)
        plt.axis('off')
    plt.tight_layout()
    plt.show()
# Visualize some test predictions
visualize_predictions(test_pairs, test_pair_labels, siamese_model, num_examples=5)
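The fixed similarity threshold of 0.5 used above (equivalent to a distance below 0.5) is a reasonable starting point given the margin of 1.0, but it is not necessarily the best cut-off. One way to pick it, sketched below with made-up distances, is to sweep candidate thresholds over held-out pairs and keep the most accurate one. The helper `best_distance_threshold` is hypothetical and not part of the tutorial code.

```python
import numpy as np

def best_distance_threshold(distances, labels, candidates):
    """Return (threshold, accuracy) maximizing accuracy of 'distance < threshold'."""
    distances = np.asarray(distances, dtype=float).ravel()
    labels = np.asarray(labels).ravel()
    best_t, best_acc = None, -1.0
    for t in candidates:
        predicted_same = (distances < t).astype(int)
        acc = float(np.mean(predicted_same == labels))
        if acc > best_acc:  # keep the first threshold that reaches the best accuracy
            best_t, best_acc = float(t), acc
    return best_t, best_acc

# Made-up distances for illustration; in practice use model.predict on held-out pairs
toy_distances = [0.10, 0.30, 0.35, 0.80, 0.90, 1.20]
toy_labels    = [1,    1,    1,    0,    0,    0]
t, acc = best_distance_threshold(toy_distances, toy_labels, np.linspace(0.0, 1.5, 16))
print(t, acc)  # a threshold near 0.4 separates this toy data perfectly
```

Ideally the threshold is tuned on a validation split that is separate from the test pairs, so the reported test accuracy stays an honest estimate.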
Real-World Applications of Siamese Networks
1. Face Recognition and Verification
Siamese Networks are widely used in facial recognition systems to verify if two face images belong to the same person.
# Pseudocode for a face verification system.
# preprocess_face is a placeholder for your own detection/cropping/scaling step.
def verify_face(known_face, unknown_face, siamese_model, threshold=0.7):
    # Preprocess faces
    known_face_processed = preprocess_face(known_face)
    unknown_face_processed = preprocess_face(unknown_face)
    # The model outputs a distance, so convert it to a similarity score
    distance = siamese_model.predict([
        np.expand_dims(known_face_processed, axis=0),
        np.expand_dims(unknown_face_processed, axis=0)
    ])[0][0]
    similarity = 1 - distance
    # Verify based on threshold
    if similarity > threshold:
        return True, similarity  # Same person
    else:
        return False, similarity  # Different person
2. Signature Verification
Banks and financial institutions use Siamese Networks to verify if a signature matches the one on record:
# Example of how a signature verification system might work.
# preprocess_signature is a placeholder for your own preprocessing step, and
# base_network is the trained shared embedding network.
def verify_signature(stored_signature, provided_signature, base_network, threshold=0.5):
    # Process signatures
    processed_stored = preprocess_signature(stored_signature)
    processed_provided = preprocess_signature(provided_signature)
    # Get embeddings
    stored_embedding = base_network.predict(np.expand_dims(processed_stored, axis=0))
    provided_embedding = base_network.predict(np.expand_dims(processed_provided, axis=0))
    # Calculate similarity as 1 - Euclidean distance between the embeddings
    similarity = 1 - np.linalg.norm(stored_embedding - provided_embedding)
    return similarity > threshold
3. One-Shot Learning
One of the most powerful applications of Siamese Networks is one-shot learning: recognizing a new class from just one example of it.
# Example of one-shot learning for classifying new objects.
# known_examples maps each class name to a single preprocessed example image.
def one_shot_classification(known_examples, new_example, siamese_model):
    # Dictionary to store similarities
    similarities = {}
    # Compare new example with each known example
    for class_name, example in known_examples.items():
        # The Siamese model returns a distance; convert it to a similarity
        similarity = 1 - siamese_model.predict([
            np.expand_dims(example, axis=0),
            np.expand_dims(new_example, axis=0)
        ])[0][0]
        similarities[class_name] = similarity
    # Return class with highest similarity
    return max(similarities, key=similarities.get)
4. Image Similarity Search
E-commerce platforms use Siamese Networks to find visually similar products:
# Pseudocode for image similarity search in a product catalog.
# preprocess_image is a placeholder for your own preprocessing step, and
# base_network is the trained shared embedding network.
def find_similar_products(query_image, product_database, base_network, top_k=5):
    # Preprocess the query image and compute its embedding
    processed_query = preprocess_image(query_image)
    query_embedding = base_network.predict(np.expand_dims(processed_query, axis=0))
    # Compare with all products
    similarities = []
    for product_id, product_image in product_database.items():
        processed_product = preprocess_image(product_image)
        product_embedding = base_network.predict(np.expand_dims(processed_product, axis=0))
        # Calculate similarity as 1 - Euclidean distance between the embeddings
        similarity = 1 - np.linalg.norm(query_embedding - product_embedding)
        similarities.append((product_id, similarity))
    # Return top K similar products
    return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
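Calling the network once per catalog item on every query becomes slow as the catalog grows. A common alternative, sketched here under the assumption that catalog embeddings are computed once and stored, is to rank the stored vectors with a single vectorized distance computation. The function `top_k_similar`, the product IDs, and the 2-D embeddings are hypothetical.

```python
import numpy as np

def top_k_similar(query_embedding, catalog_embeddings, catalog_ids, top_k=5):
    """Rank catalog items by Euclidean distance to the query embedding."""
    query = np.asarray(query_embedding, dtype=float).reshape(1, -1)
    catalog = np.asarray(catalog_embeddings, dtype=float)
    distances = np.linalg.norm(catalog - query, axis=1)
    order = np.argsort(distances)[:top_k]   # smallest distance first
    return [(catalog_ids[i], float(distances[i])) for i in order]

# Hypothetical 2-D embeddings; real ones would come from the base network
ids = ["shoe_a", "shoe_b", "bag_a", "hat_a"]
embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [3.0, 4.0], [1.0, 1.0]])
results = top_k_similar([0.0, 0.1], embeddings, ids, top_k=2)
print(results)  # the two shoe items rank first
```

For very large catalogs, an approximate nearest-neighbor index would typically replace the exhaustive distance computation shown here.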
Advanced Topics: Triplet Networks
Triplet Networks are an extension of Siamese Networks that use three inputs: an anchor, a positive (similar) example, and a negative (dissimilar) example. They're especially effective for tasks like face recognition.
def build_triplet_network(input_shape):
    # Define the base network
    base_network = build_base_network(input_shape)
    # Define the three inputs
    anchor_input = Input(shape=input_shape, name='anchor_input')
    positive_input = Input(shape=input_shape, name='positive_input')
    negative_input = Input(shape=input_shape, name='negative_input')
    # Generate embeddings for all three inputs
    anchor_embedding = base_network(anchor_input)
    positive_embedding = base_network(positive_input)
    negative_embedding = base_network(negative_input)
    # Keras applies a loss to each output separately, so concatenate the
    # embeddings into one output that a single loss function can see at once
    merged = tf.keras.layers.Concatenate(axis=1)(
        [anchor_embedding, positive_embedding, negative_embedding]
    )
    # Create the model
    triplet_model = Model(
        inputs=[anchor_input, positive_input, negative_input],
        outputs=merged
    )
    return triplet_model
# Define triplet loss function
# y_true is ignored; pass dummy targets (e.g. zeros) when calling fit()
def triplet_loss(_, y_pred):
    # Split the concatenated output back into the three embeddings
    anchor, positive, negative = tf.split(y_pred, 3, axis=1)
    # Squared distance between anchor and positive
    pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=1)
    # Squared distance between anchor and negative
    neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=1)
    # Compute triplet loss with margin
    margin = 1.0
    basic_loss = pos_dist - neg_dist + margin
    loss = tf.maximum(basic_loss, 0.0)
    return tf.reduce_mean(loss)
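The effect of the margin is easiest to see on a tiny worked example. The NumPy sketch below mirrors the triplet loss formula (squared distances, margin 1.0) and returns one loss value per triplet; the function name `triplet_loss_np` and the 2-D embeddings are invented for illustration.

```python
import numpy as np

def triplet_loss_np(anchor, positive, negative, margin=1.0):
    """Per-triplet loss using squared Euclidean distances."""
    pos_dist = np.sum((anchor - positive) ** 2, axis=1)
    neg_dist = np.sum((anchor - negative) ** 2, axis=1)
    return np.maximum(pos_dist - neg_dist + margin, 0.0)

anchor   = np.array([[0.0, 0.0], [0.0, 0.0]])
positive = np.array([[0.0, 1.0], [0.0, 1.0]])
negative = np.array([[2.0, 0.0], [1.0, 0.0]])
print(triplet_loss_np(anchor, positive, negative))  # [0. 1.]
```

In the first triplet the negative is already more than the margin farther away than the positive, so the loss is zero. In the second, the negative is as close as the positive, so the loss equals the margin.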
Summary
In this tutorial, we've:
- Introduced Siamese Networks and their architecture
- Explained how they differ from traditional neural networks
- Built a complete Siamese Network for comparing MNIST digits
- Implemented custom contrastive loss for training
- Explored real-world applications like face recognition and one-shot learning
- Introduced the concept of Triplet Networks as an extension of Siamese Networks
Siamese Networks provide powerful solutions for similarity learning problems where traditional classification approaches might not be suitable. They excel at tasks that require learning with limited data and are particularly valuable for applications like biometric verification, image matching, and few-shot learning.
Additional Resources
- FaceNet: A Unified Embedding for Face Recognition and Clustering
- Signature Verification using a "Siamese" Time Delay Neural Network
- TensorFlow Similarity Library
Exercises
- Modify the Siamese Network to use a different distance metric (e.g., cosine similarity instead of Euclidean distance).
- Implement a Siamese Network for a different dataset, such as the Omniglot dataset (often used for one-shot learning).
- Create a triplet network for face recognition using a face dataset like LFW (Labeled Faces in the Wild).
- Implement a one-shot learning system that can classify new images from just a single example per class.
- Use a pre-trained model (like ResNet) as the base network for your Siamese Network and apply it to a real-world problem such as duplicate image detection.