How to Implement Attention Mechanism in PyTorch
Key Insights
- Attention mechanisms let models dynamically focus on relevant parts of input sequences by computing weighted combinations based on query-key similarity scores
- The scaled dot-product attention formula (softmax(QK^T/√d_k)V) is the foundation for modern transformers and can be implemented in just a few lines of PyTorch
- Multi-head attention runs multiple attention operations in parallel to capture different representation subspaces, significantly improving model expressiveness over single-head variants
Introduction to Attention Mechanisms
Attention mechanisms revolutionized deep learning by solving a fundamental problem: how do we let models focus on the most relevant parts of their input? Before attention, sequence models like RNNs struggled with long-range dependencies because they compressed entire sequences into fixed-size vectors. Attention changed this by allowing models to dynamically compute weighted combinations of input elements.
The core intuition is simple. When translating “The cat sat on the mat” to French, the model should focus heavily on “cat” when generating “chat.” Attention provides exactly this capability—a learned mechanism to assign importance scores to different input positions.
Today, attention is everywhere: transformers, vision models, speech recognition, and even reinforcement learning. Understanding how to implement it from scratch gives you the foundation to work with modern architectures effectively.
Mathematical Foundation
Attention boils down to three components: queries (Q), keys (K), and values (V). Think of it like a database lookup. Your query searches through keys to find matches, then retrieves corresponding values. The attention score determines how much each value contributes to the output.
The scaled dot-product attention formula is:
Attention(Q, K, V) = softmax(QK^T / √d_k)V
Here, d_k is the dimension of the key vectors. We divide by √d_k to prevent the dot products from growing too large, which would push softmax into regions with tiny gradients.
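To see why the √d_k scaling matters, here is a quick illustrative check: for random vectors with unit-variance components, the raw dot product's standard deviation grows like √d_k, while the scaled version stays near 1 regardless of dimension.

```python
import torch

torch.manual_seed(0)

# For random vectors with unit-variance components, q·k has variance d_k,
# so unscaled scores grow like sqrt(d_k) and push softmax toward saturation.
for d_k in (8, 64, 512):
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    raw = (q * k).sum(dim=-1)      # unscaled dot products
    scaled = raw / d_k ** 0.5      # what the attention formula uses
    print(f"d_k={d_k:3d}  raw std={raw.std():.2f}  scaled std={scaled.std():.2f}")
```

The raw standard deviations track √8 ≈ 2.8, √64 = 8, and √512 ≈ 22.6, while the scaled ones all hover around 1.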
Let’s implement this from scratch:
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Args:
        query: Tensor of shape (batch_size, seq_len, d_k)
        key: Tensor of shape (batch_size, seq_len, d_k)
        value: Tensor of shape (batch_size, seq_len, d_v)
        mask: Optional tensor for masking positions
    Returns:
        output: Weighted values (batch_size, seq_len, d_v)
        attention_weights: Attention scores (batch_size, seq_len, seq_len)
    """
    d_k = query.size(-1)
    # Compute attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # Apply mask if provided (useful for causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    # Compute weighted sum of values
    output = torch.matmul(attention_weights, value)
    return output, attention_weights
# Test it
batch_size, seq_len, d_model = 2, 4, 8
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}") # (2, 4, 8)
print(f"Attention weights shape: {weights.shape}") # (2, 4, 4)
print(f"Weights sum to 1: {weights.sum(dim=-1)}") # Each row sums to 1
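The mask argument is what enables causal attention for autoregressive models. Here is a standalone sketch, using the same masking logic as the function above: a lower-triangular mask prevents each position from attending to positions after it.

```python
import torch
import torch.nn.functional as F

seq_len = 4
# Lower-triangular mask: position i may attend only to positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

scores = torch.randn(1, seq_len, seq_len)
scores = scores.masked_fill(causal_mask == 0, -1e9)
weights = F.softmax(scores, dim=-1)

# Entries above the diagonal are (numerically) zero: no attending to
# future positions, and each row still sums to 1.
print(weights[0])
```

To use this with the function above, pass `causal_mask` as the `mask` argument; it broadcasts across the batch dimension.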
Building a Simple Attention Layer
Now let’s wrap this in a proper PyTorch module with learnable parameters. An attention layer needs linear transformations to project inputs into Q, K, and V spaces.
import torch.nn as nn
class AttentionLayer(nn.Module):
    def __init__(self, d_model):
        """
        Args:
            d_model: Dimension of input embeddings
        """
        super(AttentionLayer, self).__init__()
        self.d_model = d_model
        # Linear layers to project to Q, K, V
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        # Project inputs to Q, K, V
        Q = self.query_proj(query)
        K = self.key_proj(key)
        V = self.value_proj(value)
        # Apply attention
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        # Final linear projection
        output = self.output_proj(attn_output)
        return output, attn_weights
# Visualize attention weights
attention_layer = AttentionLayer(d_model=64)
x = torch.randn(1, 5, 64) # Single sequence of length 5
output, weights = attention_layer(x, x, x)
print("Attention weight matrix:")
print(weights[0].detach().numpy().round(2))
The attention weight matrix shows how much each position attends to every other position. Rows sum to 1.0, and higher values indicate stronger connections.
Self-Attention Implementation
Self-attention is a special case where queries, keys, and values all come from the same input sequence. This lets the model capture relationships between different positions in the same sequence—crucial for language understanding.
class SelfAttention(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super(SelfAttention, self).__init__()
        self.attention = AttentionLayer(d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Self-attention: Q, K, V all come from x
        attn_output, attn_weights = self.attention(x, x, x, mask)
        # Residual connection and layer norm
        x = self.layer_norm(x + self.dropout(attn_output))
        return x, attn_weights
# Apply to sequence data
vocab_size, seq_len, d_model = 1000, 10, 128
embedding = nn.Embedding(vocab_size, d_model)
# Simulate tokenized text
tokens = torch.randint(0, vocab_size, (2, seq_len))
embedded = embedding(tokens)
self_attn = SelfAttention(d_model)
output, weights = self_attn(embedded)
print(f"Self-attention output shape: {output.shape}") # (2, 10, 128)
Notice the residual connection and layer normalization—these are standard practices that stabilize training and improve gradient flow.
Multi-Head Attention
Single attention heads have limited capacity. Multi-head attention runs multiple attention operations in parallel, each learning different aspects of the relationships. Think of it as having multiple “attention experts” that specialize in different patterns.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        """Split last dimension into (num_heads, d_k)"""
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of split_heads"""
        batch_size, num_heads, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, query, key, value, mask=None):
        # Linear projections, then split into heads
        Q = self.split_heads(self.query_proj(query))
        K = self.split_heads(self.key_proj(key))
        V = self.split_heads(self.value_proj(value))
        # Scaled dot-product attention for each head
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, V)
        # Combine heads and apply final linear projection
        output = self.combine_heads(attn_output)
        output = self.output_proj(output)
        return output, attn_weights
# Compare single vs multi-head
single_head = MultiHeadAttention(d_model=128, num_heads=1)
multi_head = MultiHeadAttention(d_model=128, num_heads=8)
x = torch.randn(2, 10, 128)
out_single, _ = single_head(x, x, x)
out_multi, _ = multi_head(x, x, x)
print(f"Single head output: {out_single.shape}")
print(f"Multi head output: {out_multi.shape}")
Multi-head attention with 8 heads means 8 parallel attention operations, each working in a 16-dimensional subspace (128/8). This parallelism captures richer patterns than a single 128-dimensional attention.
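The head bookkeeping itself is pure reshaping. A standalone check of the split/combine round trip and the per-head shapes:

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 10, 128, 8
d_k = d_model // num_heads  # 16 dimensions per head

x = torch.randn(batch_size, seq_len, d_model)

# split_heads: (B, S, D) -> (B, H, S, D/H)
heads = x.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 10, 16])

# combine_heads is the exact inverse, so no information is lost
restored = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
print(torch.equal(restored, x))  # True
```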
Practical Application: Sequence Classification
Let’s build a complete sentiment classifier using our attention layers. This demonstrates how attention fits into real models.
class AttentionClassifier(nn.Module):
    def __init__(self, vocab_size, d_model=128, num_heads=4, num_classes=2):
        super(AttentionClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # x: (batch_size, seq_len)
        x = self.embedding(x)  # (batch_size, seq_len, d_model)
        # Self-attention
        attn_out, _ = self.mha(x, x, x)
        x = self.layer_norm(x + self.dropout(attn_out))
        # Global average pooling
        x = x.mean(dim=1)  # (batch_size, d_model)
        # Classification
        logits = self.fc(x)
        return logits
# Training example
model = AttentionClassifier(vocab_size=5000, d_model=128, num_heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# Dummy training data
train_texts = torch.randint(0, 5000, (32, 50)) # 32 sequences of length 50
train_labels = torch.randint(0, 2, (32,)) # Binary labels
# Training step
model.train()
optimizer.zero_grad()
logits = model(train_texts)
loss = criterion(logits, train_labels)
loss.backward()
optimizer.step()
print(f"Training loss: {loss.item():.4f}")
This classifier uses attention to capture relationships between words, pools the sequence representations, and classifies sentiment. For production use, you’d add positional encodings, multiple attention layers, and train on real data like IMDB reviews.
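Of those additions, positional encodings matter most, since attention by itself is order-invariant. A minimal sketch of the sinusoidal encoding from the original transformer paper, added to the embeddings before the attention layer (the function name here is ours):

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len, d_model):
    """PE[pos, 2i] = sin(pos/10000^(2i/d)), PE[pos, 2i+1] = cos(pos/10000^(2i/d))."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pe = sinusoidal_positional_encoding(50, 128)
print(pe.shape)  # torch.Size([50, 128])
# Usage in the classifier: x = self.embedding(x) + pe.unsqueeze(0)
```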
Conclusion and Next Steps
You now understand how to implement attention mechanisms from first principles in PyTorch. The scaled dot-product attention formula is surprisingly simple, yet it powers state-of-the-art models across domains.
Key takeaways: attention computes weighted combinations based on learned similarity scores, self-attention captures intra-sequence relationships, and multi-head attention increases model capacity through parallel attention operations.
For production code, use PyTorch’s built-in nn.MultiheadAttention—it’s optimized and battle-tested. But understanding the implementation details helps you debug issues, customize architectures, and stay current as the field evolves.
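For reference, here is a usage sketch of the built-in module; note batch_first=True, which matches the (batch, seq, dim) layout used throughout this article.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 128)

# Self-attention: query, key, and value are all x
out, weights = mha(x, x, x)
print(out.shape)      # torch.Size([2, 10, 128])
print(weights.shape)  # torch.Size([2, 10, 10]) -- averaged over heads by default
```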
Next steps: explore causal masking for autoregressive models, cross-attention for encoder-decoder architectures, relative positional encodings, and attention optimization techniques like flash attention. The attention mechanism is your foundation for working with transformers, and you’re now equipped to build and modify them confidently.