How to Implement Self-Attention in PyTorch
Key Insights
- Self-attention computes relationships between all positions in a sequence by transforming inputs into Query, Key, and Value matrices, then using scaled dot-product operations to weight each position’s contribution to the output.
- Multi-head attention runs several attention mechanisms in parallel with different learned projections, allowing the model to capture diverse relational patterns simultaneously across different representation subspaces.
- PyTorch’s built-in `torch.nn.MultiheadAttention` is optimized for production use, but implementing attention from scratch reveals the elegant simplicity behind transformer architectures and helps debug attention-related issues.
Introduction to Self-Attention
Self-attention is the core mechanism that powers transformers, enabling models like BERT, GPT, and Vision Transformers to understand relationships between elements in a sequence. Unlike recurrent networks that process sequences sequentially, self-attention computes interactions between all positions simultaneously.
The fundamental idea: each token in a sequence attends to every other token (including itself) to determine how much each position should influence its representation. For the sentence “The cat sat on the mat,” self-attention helps the model learn that “cat” and “sat” are closely related, while “the” might attend strongly to the noun it modifies.
This parallel computation makes transformers both faster to train and more capable of capturing long-range dependencies than RNNs or LSTMs.
Mathematical Foundation
Self-attention uses three learned transformations of the input: Query (Q), Key (K), and Value (V). Think of it like a database lookup: queries search for relevant keys, and when matches are found, corresponding values are retrieved.
The attention formula is:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Let’s break this down with a minimal NumPy example:
```python
import numpy as np

np.random.seed(0)

# 3 tokens, embedding dimension of 4
X = np.random.randn(3, 4)
print("Input shape:", X.shape)  # (3, 4)

# Initialize Q, K, V projection matrices
W_q = np.random.randn(4, 4)
W_k = np.random.randn(4, 4)
W_v = np.random.randn(4, 4)

# Project input to Q, K, V
Q = X @ W_q  # (3, 4)
K = X @ W_k  # (3, 4)
V = X @ W_v  # (3, 4)

# Compute scaled attention scores
d_k = Q.shape[-1]
scores = Q @ K.T / np.sqrt(d_k)  # (3, 3)
print("Attention scores:\n", scores)

# Softmax over each row (subtracting the row max for numerical stability)
exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
attention_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
print("Attention weights:\n", attention_weights)

# Compute weighted sum of values
output = attention_weights @ V  # (3, 4)
print("Output shape:", output.shape)  # (3, 4)
```
Each row in the attention weights matrix shows how much each token attends to every other token. The softmax ensures weights sum to 1, creating a weighted average of the value vectors.
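We can check this normalization property directly; a minimal sketch using fresh random projections:

```python
import numpy as np

# Recompute a small attention example and verify that each row of the
# attention weights sums to 1, i.e. each output is a weighted average.
rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 4))
K = rng.standard_normal((3, 4))

scores = Q @ K.T / np.sqrt(Q.shape[-1])
# Numerically stable softmax: subtract the row-wise max first
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(weights.sum(axis=-1))  # each entry is 1.0 (up to float rounding)
```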
Basic Self-Attention Implementation
Here’s a clean PyTorch implementation:
```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        # x shape: (batch_size, seq_len, embed_dim)
        B, T, C = x.shape

        # Project to Q, K, V
        Q = self.W_q(x)  # (B, T, C)
        K = self.W_k(x)  # (B, T, C)
        V = self.W_v(x)  # (B, T, C)

        # Compute scaled attention scores
        scores = Q @ K.transpose(-2, -1) / (C ** 0.5)  # (B, T, T)

        # Normalize scores into attention weights
        attn_weights = torch.softmax(scores, dim=-1)  # (B, T, T)

        # Weighted sum of values
        output = attn_weights @ V  # (B, T, C)
        return output, attn_weights

# Test it
batch_size, seq_len, embed_dim = 2, 5, 64
x = torch.randn(batch_size, seq_len, embed_dim)
attn = SelfAttention(embed_dim)
output, weights = attn(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
```
Scaled Dot-Product Attention
The scaling factor 1/√d_k prevents the dot products from growing too large, which would push the softmax into regions with extremely small gradients. Let’s add masking support for padding and causal attention:
```python
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (B, T, d_k)
        d_k = Q.size(-1)

        # Compute scaled attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

        # Apply mask (for padding or causal attention):
        # positions where mask == 0 are excluded from attention
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax and dropout
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

# Create a causal mask (for autoregressive models):
# position i may attend only to positions <= i
seq_len = 5
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
print("Causal mask:\n", causal_mask.int())

# Test with mask
Q = K = V = torch.randn(1, seq_len, 64)
attn = ScaledDotProductAttention()
output, weights = attn(Q, K, V, mask=causal_mask)
print("Masked attention weights:\n", weights[0].detach().numpy().round(3))
```
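The effect of the 1/√d_k factor can also be seen empirically: for unit-variance inputs, a raw dot product q·k has variance roughly d_k, so its spread grows with dimension, while the scaled version stays near 1. A quick sketch:

```python
import torch

torch.manual_seed(0)

# Compare the spread of raw vs. scaled dot products as d_k grows.
# Without scaling, scores blow up and push softmax toward saturation.
for d_k in [16, 64, 256]:
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    raw = (q * k).sum(-1)          # unscaled dot products
    scaled = raw / d_k ** 0.5      # scaled dot products
    print(f"d_k={d_k:4d}  raw std={raw.std().item():6.2f}  "
          f"scaled std={scaled.std().item():5.2f}")
```

The scaled standard deviation hovers around 1 regardless of d_k, which keeps the softmax in a well-conditioned regime.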
Multi-Head Attention
Multi-head attention runs multiple attention operations in parallel, each with different learned projections. This allows the model to attend to different aspects of the input simultaneously:
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One projection per role covers all heads at once (more efficient)
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, T, C = x.shape

        # Project, then split the channel dimension into heads
        Q = self.W_q(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, T, head_dim)
        K = self.W_k(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention, computed for all heads in parallel
        scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        attn_output = attn_weights @ V  # (B, num_heads, T, head_dim)

        # Concatenate heads back into a single channel dimension
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)

        # Final linear projection
        output = self.W_o(attn_output)
        return output, attn_weights

# Test multi-head attention
mha = MultiHeadAttention(embed_dim=512, num_heads=8)
x = torch.randn(2, 10, 512)
output, weights = mha(x)
print(f"Output shape: {output.shape}")
print(f"Attention weights shape (per head): {weights.shape}")
```
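For comparison, PyTorch's built-in module accepts the same kind of causal mask, with one caveat: its boolean `attn_mask` uses the opposite convention (True marks positions that may *not* be attended). A self-contained sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, E = 2, 5, 64

mha = nn.MultiheadAttention(embed_dim=E, num_heads=8, batch_first=True)
x = torch.randn(B, T, E)

# Built-in convention: True marks BLOCKED positions, so the causal mask
# is the strict upper triangle (the opposite of the tril mask above).
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

out, weights = mha(x, x, x, attn_mask=causal)
print(out.shape)      # torch.Size([2, 5, 64])
print(weights.shape)  # torch.Size([2, 5, 5]), averaged over heads
```

Note that by default the returned weights are averaged across heads; pass `average_attn_weights=False` to get per-head weights.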
Practical Integration
Here’s a complete encoder block using self-attention in a sequence classification model:
```python
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention with residual connection
        attn_out, _ = self.attention(x)
        x = self.norm1(x + attn_out)
        # Feedforward with residual connection
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        return x

class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, num_classes, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional encoding, up to max_len positions
        self.pos_encoding = nn.Parameter(torch.randn(1, max_len, embed_dim))
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, embed_dim * 4)
            for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B, T = x.shape
        x = self.embedding(x) + self.pos_encoding[:, :T, :]
        for block in self.blocks:
            x = block(x)
        # Global average pooling over the sequence dimension
        x = x.mean(dim=1)
        return self.classifier(x)

# Quick forward pass
model = SimpleTransformer(vocab_size=1000, embed_dim=128, num_heads=4, num_layers=2, num_classes=2)
x = torch.randint(0, 1000, (8, 50))  # batch of token-id sequences
logits = model(x)
print(f"Logits shape: {logits.shape}")  # (8, 2)
```
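A single optimization step follows the standard PyTorch pattern. Purely as an illustration (random labels, and a tiny stand-in classifier so the snippet runs on its own; substitute the SimpleTransformer from above in practice):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in: embedding + mean pooling + linear head.
# It shares the SimpleTransformer interface (token ids in, logits out).
class TinyClassifier(nn.Module):
    def __init__(self, vocab_size=1000, embed_dim=128, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        return self.classifier(self.embedding(x).mean(dim=1))

model = TinyClassifier()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

x = torch.randint(0, 1000, (8, 50))  # batch of token-id sequences
y = torch.randint(0, 2, (8,))        # random labels, illustrative only

# One training step: forward, loss, backward, parameter update
model.train()
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```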
Optimization and Best Practices
For production code, use PyTorch’s optimized implementation:
```python
import time

# PyTorch's built-in multi-head attention
builtin_mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
custom_mha = MultiHeadAttention(512, 8)
x = torch.randn(32, 100, 512)

# Compare inference speed: eval mode (disables dropout) and no_grad
# (skips autograd bookkeeping) for a fair measurement
custom_mha.eval()
builtin_mha.eval()

with torch.no_grad():
    # Custom implementation
    start = time.time()
    for _ in range(100):
        _ = custom_mha(x)
    custom_time = time.time() - start

    # Built-in implementation (query, key, and value all come from x)
    start = time.time()
    for _ in range(100):
        _ = builtin_mha(x, x, x)
    builtin_time = time.time() - start

print(f"Custom: {custom_time:.3f}s")
print(f"Built-in: {builtin_time:.3f}s")
print(f"Speedup: {custom_time / builtin_time:.2f}x")
```
Key optimizations to consider:
- Use `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+) for flash attention and memory efficiency
- Enable `torch.compile()` for JIT optimization
- Use mixed precision training with `torch.cuda.amp` for faster computation
- Implement gradient checkpointing for very deep models to trade compute for memory
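As a sketch, `F.scaled_dot_product_attention` (PyTorch 2.0+) fuses the manual scores/masking/softmax/matmul sequence into one call and dispatches to flash attention when the hardware supports it:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 8, 10, 64  # batch, heads, sequence length, head dim

q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

# Fused kernel: scaling, causal masking, softmax, and the value
# matmul all happen inside one call
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 10, 64])

# Equivalent manual computation for comparison
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = (q @ k.transpose(-2, -1)) / D ** 0.5
scores = scores.masked_fill(~mask, float('-inf'))
manual = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(out, manual, atol=1e-4))
```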
The built-in implementation will typically be 2-3x faster due to fused CUDA kernels and optimized memory layouts. However, custom implementations are invaluable for research, debugging attention patterns, and implementing novel attention variants like sparse attention or linear attention mechanisms.