PyTorch Attention Mechanisms
In the world of Natural Language Processing (NLP), attention mechanisms have revolutionized how neural networks understand and process language. They're the secret sauce behind powerful models like BERT, GPT, and other transformer architectures. But what exactly are attention mechanisms, and how do we implement them in PyTorch?
This tutorial will guide you through understanding attention mechanisms from the ground up, with practical PyTorch implementations to help solidify your understanding.
What is Attention?
At its core, attention is a technique that allows a model to focus on specific parts of input data when producing an output. Think of how you read this sentence—your eyes may jump back to earlier words to establish context. Attention mechanisms let neural networks do something similar.
In NLP, attention helps models decide which words in a sentence are most important for understanding its meaning, especially when dealing with tasks like translation or summarization where context is crucial.
Types of Attention Mechanisms
Before diving into code, let's understand the main types of attention:
- Basic Attention: Assigning importance weights to input elements
- Self-Attention: Relating different positions in a single sequence
- Multi-Head Attention: Running attention multiple times in parallel
Implementing Basic Attention in PyTorch
Let's start by implementing a simple attention mechanism in PyTorch:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BasicAttention, self).__init__()
        # Attention projection layer
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, encoder_outputs):
        # encoder_outputs shape: [batch_size, seq_len, hidden_size]

        # Calculate attention weights
        attention_weights = self.attention(encoder_outputs)  # [batch_size, seq_len, 1]
        attention_weights = F.softmax(attention_weights.squeeze(-1), dim=1)  # [batch_size, seq_len]

        # Apply attention weights to get context vector
        attention_weights = attention_weights.unsqueeze(2)  # [batch_size, seq_len, 1]
        context = torch.sum(encoder_outputs * attention_weights, dim=1)  # [batch_size, hidden_size]

        return context, attention_weights.squeeze(2)
How Basic Attention Works:
- We take sequence outputs from an encoder (like LSTM outputs)
- We calculate an "importance score" for each element in the sequence
- We normalize these scores using softmax to get attention weights
- We compute a weighted sum of the inputs based on these weights
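In symbols, with encoder outputs $h_1, \dots, h_T$ and the scoring layer's weight vector $w$ (the nn.Linear(hidden_size, 1) above, bias omitted), the module computes

$$\alpha_i = \frac{\exp(w^\top h_i)}{\sum_{j=1}^{T} \exp(w^\top h_j)}, \qquad c = \sum_{i=1}^{T} \alpha_i h_i$$

so the context vector $c$ is simply a weighted average of the encoder outputs.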
Let's see this in action:
# Sample data
batch_size = 3
seq_len = 5
hidden_size = 10
# Create random encoder outputs
encoder_outputs = torch.randn(batch_size, seq_len, hidden_size)
# Initialize attention module
attention = BasicAttention(hidden_size)
# Apply attention
context, attention_weights = attention(encoder_outputs)
print("Encoder outputs shape:", encoder_outputs.shape)
print("Context vector shape:", context.shape)
print("Attention weights shape:", attention_weights.shape)
print("\nAttention weights (sample):")
print(attention_weights[0]) # Print weights for first item in batch
Output:
Encoder outputs shape: torch.Size([3, 5, 10])
Context vector shape: torch.Size([3, 10])
Attention weights shape: torch.Size([3, 5])
Attention weights (sample):
tensor([0.1924, 0.1889, 0.1887, 0.2231, 0.2069], grad_fn=<SelectBackward0>)
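Because the weights come out of a softmax, the weights for each sequence sum to 1. You can verify this directly:
# Each row of attention weights should sum to (approximately) 1
print(attention_weights.sum(dim=1))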
Self-Attention: The Building Block of Transformers
Self-attention is where things get really interesting! It's the foundation of transformer models and allows each position in a sequence to attend to all positions, capturing relationships regardless of their distance.
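Concretely, the implementation below computes scaled dot-product attention,

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $Q$, $K$, and $V$ are learned linear projections of the same input sequence and $d_k$ is the dimension used for scaling (embed_size in the code below).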
Let's implement a simple self-attention mechanism:
class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size

        # Query, Key, Value projections
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

        # Scaling factor
        self.scale = torch.sqrt(torch.FloatTensor([embed_size]))

    def forward(self, x, mask=None):
        # x shape: [batch_size, seq_len, embed_size]
        batch_size, seq_len, _ = x.shape

        # Create Q, K, V projections
        Q = self.query(x)  # [batch_size, seq_len, embed_size]
        K = self.key(x)    # [batch_size, seq_len, embed_size]
        V = self.value(x)  # [batch_size, seq_len, embed_size]

        # Transpose K for matrix multiplication
        K_t = K.transpose(1, 2)  # [batch_size, embed_size, seq_len]

        # Calculate raw attention scores
        energy = torch.bmm(Q, K_t) / self.scale.to(x.device)  # [batch_size, seq_len, seq_len]

        # Apply mask if provided (useful for preventing attention to padding)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        # Apply softmax to get attention weights
        attention = F.softmax(energy, dim=2)  # [batch_size, seq_len, seq_len]

        # Multiply by values
        out = torch.bmm(attention, V)  # [batch_size, seq_len, embed_size]

        return out, attention
Let's Try Self-Attention:
# Sample data
batch_size = 2
seq_len = 4
embed_size = 8
# Create a sample input sequence
x = torch.randn(batch_size, seq_len, embed_size)
# Initialize self-attention module
self_attention = SelfAttention(embed_size)
# Apply self-attention
output, attention_matrix = self_attention(x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)
print("Attention matrix shape:", attention_matrix.shape)
print("\nAttention matrix for first sequence:")
print(attention_matrix[0].round(decimals=2))
Output:
Input shape: torch.Size([2, 4, 8])
Output shape: torch.Size([2, 4, 8])
Attention matrix shape: torch.Size([2, 4, 4])
Attention matrix for first sequence:
tensor([[0.30, 0.25, 0.22, 0.23],
[0.22, 0.28, 0.26, 0.24],
[0.21, 0.27, 0.27, 0.25],
[0.23, 0.26, 0.24, 0.27]], grad_fn=<RoundBackward>)
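The mask argument wasn't exercised in this run. Continuing the same example, here is one way (a minimal sketch) to build a padding mask so that padded positions receive no attention:
# Suppose the two sequences have true lengths 4 and 2 (positions beyond that are padding)
lengths = torch.tensor([4, 2])
# Build a padding mask: 1 for real tokens, 0 for padded positions.
# Shape [batch_size, 1, seq_len] so it broadcasts over the query dimension of the scores.
mask = (torch.arange(seq_len)[None, :] < lengths[:, None]).unsqueeze(1)
masked_output, masked_attention = self_attention(x, mask=mask)
print(masked_attention[1])  # padded key positions now receive (near) zero weight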
Multi-Head Attention: Attention in Parallel
Multi-head attention allows the model to attend to information from different representation subspaces. Essentially, it's running multiple self-attention operations in parallel and then combining the results.
Here's how we can implement it:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, n_heads):
        super(MultiHeadAttention, self).__init__()
        assert embed_size % n_heads == 0, "Embedding size must be divisible by number of heads"

        self.embed_size = embed_size
        self.n_heads = n_heads
        self.head_dim = embed_size // n_heads

        # Linear layers for Q, K, V projections
        self.q_linear = nn.Linear(embed_size, embed_size)
        self.k_linear = nn.Linear(embed_size, embed_size)
        self.v_linear = nn.Linear(embed_size, embed_size)

        # Output projection
        self.fc_out = nn.Linear(embed_size, embed_size)

        # Scaling factor
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim]))

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]

        # Linear transformations
        Q = self.q_linear(query)  # [batch_size, query_len, embed_size]
        K = self.k_linear(key)    # [batch_size, key_len, embed_size]
        V = self.v_linear(value)  # [batch_size, value_len, embed_size]

        # Reshape for multi-head attention
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # Now Q is [batch_size, n_heads, query_len, head_dim]
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # Now K is [batch_size, n_heads, key_len, head_dim]
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # Now V is [batch_size, n_heads, value_len, head_dim]

        # Calculate attention scores
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale.to(query.device)
        # energy: [batch_size, n_heads, query_len, key_len]

        # Apply mask if provided
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        # Apply softmax
        attention = torch.softmax(energy, dim=-1)
        # attention: [batch_size, n_heads, query_len, key_len]

        # Get weighted sum
        x = torch.matmul(attention, V)
        # x: [batch_size, n_heads, query_len, head_dim]

        # Reshape to combine all heads
        x = x.permute(0, 2, 1, 3).contiguous()
        # x: [batch_size, query_len, n_heads, head_dim]
        x = x.view(batch_size, -1, self.embed_size)
        # x: [batch_size, query_len, embed_size]

        # Final linear layer
        x = self.fc_out(x)

        return x, attention
Let's Try Multi-Head Attention:
# Sample data
batch_size = 2
seq_len = 5
embed_size = 8
n_heads = 2
# Create sample input sequences
query = torch.randn(batch_size, seq_len, embed_size)
key = torch.randn(batch_size, seq_len, embed_size)
value = torch.randn(batch_size, seq_len, embed_size)
# Initialize multi-head attention module
multihead_attention = MultiHeadAttention(embed_size, n_heads)
# Apply multi-head attention
output, attention = multihead_attention(query, key, value)
print("Query shape:", query.shape)
print("Output shape:", output.shape)
print("Attention shape:", attention.shape)
Output:
Query shape: torch.Size([2, 5, 8])
Output shape: torch.Size([2, 5, 8])
Attention shape: torch.Size([2, 2, 5, 5])
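For real projects you will usually reach for PyTorch's built-in nn.MultiheadAttention, which implements the same idea (plus dropout, optional bias, and padding masks). A rough equivalent of the example above, assuming a reasonably recent PyTorch version (the batch_first and average_attn_weights options were added in newer releases):
# Built-in multi-head attention; batch_first=True expects [batch, seq, embed] inputs
mha = nn.MultiheadAttention(embed_dim=embed_size, num_heads=n_heads, batch_first=True)
# average_attn_weights=False keeps one attention map per head, like our implementation
builtin_out, builtin_attn = mha(query, key, value, average_attn_weights=False)
print(builtin_out.shape)   # torch.Size([2, 5, 8])
print(builtin_attn.shape)  # torch.Size([2, 2, 5, 5])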
Practical Example: Attention-Based Text Classification
Now, let's build a practical example: a text classifier using attention to identify sentiment in reviews. We'll use attention to help our model focus on the most important words in each review.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
class AttentionTextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_classes, dropout=0.5):
        super(AttentionTextClassifier, self).__init__()

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            embed_size,
            hidden_size,
            bidirectional=True,
            batch_first=True
        )

        # Attention layer
        self.attention = nn.Linear(hidden_size * 2, 1)

        # Output layer
        self.fc = nn.Linear(hidden_size * 2, num_classes)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, text, text_lengths):
        # text shape: [batch_size, seq_len]

        # Get embeddings
        embedded = self.embedding(text)  # [batch_size, seq_len, embed_size]
        embedded = self.dropout(embedded)

        # Pack sequence for LSTM
        packed = pack_padded_sequence(
            embedded,
            text_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )

        # Pass through LSTM
        packed_output, (hidden, cell) = self.lstm(packed)
        output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)
        # output shape: [batch_size, seq_len, hidden_size*2]

        # Calculate attention weights
        attention_weights = self.attention(output)  # [batch_size, seq_len, 1]
        attention_weights = F.softmax(attention_weights.squeeze(2), dim=1)  # [batch_size, seq_len]

        # Create attention mask for padding
        mask = torch.arange(output.shape[1])[None, :].to(text.device) < text_lengths[:, None]
        attention_weights = attention_weights * mask.float()

        # Normalize weights for padded sequences
        attention_weights = attention_weights / (attention_weights.sum(dim=1, keepdim=True) + 1e-9)

        # Apply attention weights
        context = torch.bmm(
            attention_weights.unsqueeze(1),  # [batch_size, 1, seq_len]
            output                           # [batch_size, seq_len, hidden_size*2]
        )  # [batch_size, 1, hidden_size*2]
        context = context.squeeze(1)  # [batch_size, hidden_size*2]

        # Final classification
        output = self.fc(self.dropout(context))  # [batch_size, num_classes]

        return output, attention_weights
How to Use the Attention-Based Classifier:
# Example parameters
vocab_size = 10000
embed_size = 100
hidden_size = 128
num_classes = 2 # Binary classification (positive/negative)
# Create model
model = AttentionTextClassifier(vocab_size, embed_size, hidden_size, num_classes)
# Sample batch
batch_size = 3
max_seq_len = 10
# Sample data (in a real-world scenario, this would be tokenized text)
text = torch.randint(0, vocab_size, (batch_size, max_seq_len))
# Variable length sequences
text_lengths = torch.tensor([10, 7, 5])
# Forward pass
predictions, attention_weights = model(text, text_lengths)
print("Input text shape:", text.shape)
print("Predictions shape:", predictions.shape)
print("Attention weights shape:", attention_weights.shape)
# Visualize attention weights for the first sequence
print("\nAttention weights for first sequence:")
print(attention_weights[0][:text_lengths[0]].detach().numpy().round(3))
Output:
Input text shape: torch.Size([3, 10])
Predictions shape: torch.Size([3, 2])
Attention weights shape: torch.Size([3, 10])
Attention weights for first sequence:
[0.103 0.104 0.091 0.098 0.099 0.101 0.102 0.103 0.1 0.099]
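To actually train the classifier, wrap it in a standard optimization loop. Here is a minimal sketch, assuming a hypothetical train_loader that yields batches of (token IDs, lengths, class labels); none of these data objects are defined above:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train_epoch(model, train_loader):
    model.train()
    total_loss = 0.0
    for text, text_lengths, labels in train_loader:  # hypothetical DataLoader
        optimizer.zero_grad()
        predictions, _ = model(text, text_lengths)  # attention weights not needed for the loss
        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)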
Real-World Application: Building a Neural Machine Translation System
Attention mechanisms are particularly powerful in machine translation. Let's sketch out a simplified version of a Neural Machine Translation (NMT) model with attention:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input_seq, hidden=None):
        embedded = self.embedding(input_seq)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

class AttentionDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(AttentionDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = nn.Linear(hidden_size * 2, 1)
        self.gru = nn.GRU(hidden_size * 2, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input_step, hidden, encoder_outputs):
        # input_step shape: [batch_size, 1]
        # hidden shape: [1, batch_size, hidden_size]
        # encoder_outputs shape: [batch_size, seq_len, hidden_size]

        # Get embedding of current input word
        embedded = self.embedding(input_step)  # [batch_size, 1, hidden_size]

        # Expand the decoder hidden state across the source sequence
        hidden_expanded = hidden.transpose(0, 1)  # [batch_size, 1, hidden_size]
        hidden_expanded = hidden_expanded.expand(-1, encoder_outputs.size(1), -1)

        # Concatenate hidden state with each encoder output
        attn_inputs = torch.cat((encoder_outputs, hidden_expanded), dim=2)

        # Calculate attention weights
        attn_weights = self.attention(attn_inputs)  # [batch_size, seq_len, 1]
        attn_weights = F.softmax(attn_weights.squeeze(-1), dim=1)  # [batch_size, seq_len]

        # Apply attention weights to encoder outputs
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # [batch_size, 1, hidden_size]

        # Combine embedded input word and context vector
        gru_input = torch.cat((embedded, context), dim=2)  # [batch_size, 1, hidden_size*2]

        # GRU output
        output, hidden = self.gru(gru_input, hidden)  # output: [batch_size, 1, hidden_size]

        # Final output layer
        output = self.out(output.squeeze(1))  # [batch_size, output_size]
        output = F.log_softmax(output, dim=1)

        return output, hidden, attn_weights
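Before writing a full training or inference loop, it helps to sanity-check the shapes with random data. The vocabulary and hidden sizes below are arbitrary toy values:
# Toy sizes: source vocab of 100 tokens, target vocab of 120, hidden size 16
encoder = EncoderRNN(input_size=100, hidden_size=16)
decoder = AttentionDecoderRNN(hidden_size=16, output_size=120)

src = torch.randint(0, 100, (2, 7))             # [batch_size=2, src_len=7]
encoder_outputs, encoder_hidden = encoder(src)  # [2, 7, 16], [1, 2, 16]

decoder_input = torch.randint(0, 120, (2, 1))   # one target token per batch element
output, hidden, attn_weights = decoder(decoder_input, encoder_hidden, encoder_outputs)

print(output.shape)        # torch.Size([2, 120]) - log-probabilities over the target vocab
print(attn_weights.shape)  # torch.Size([2, 7])   - one weight per source position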
Here's how you would use this model for neural machine translation:
def translate_sentence(input_sentence, encoder, decoder, input_vocab, output_vocab, max_length=50):
    # Convert sentence to tensor
    input_indices = [input_vocab[word] for word in input_sentence.split()]
    input_tensor = torch.LongTensor(input_indices).unsqueeze(0)  # [1, seq_len]

    # Encode the sentence
    encoder_outputs, encoder_hidden = encoder(input_tensor)

    # Start with SOS token (this sketch assumes '<SOS>' and '<EOS>' entries in output_vocab)
    decoder_input = torch.LongTensor([[output_vocab['<SOS>']]])
    decoder_hidden = encoder_hidden

    index_to_word = {index: word for word, index in output_vocab.items()}
    translated_words = []

    # Greedily decode one token at a time until EOS or max_length
    for _ in range(max_length):
        output, decoder_hidden, attn_weights = decoder(decoder_input, decoder_hidden, encoder_outputs)
        top_index = output.argmax(dim=1).item()
        if top_index == output_vocab['<EOS>']:
            break
        translated_words.append(index_to_word[top_index])
        decoder_input = torch.LongTensor([[top_index]])

    return ' '.join(translated_words)