How to Implement Self-Attention in PyTorch

Key Insights

  • Self-attention computes relationships between all positions in a sequence by transforming inputs into Query, Key, and Value matrices, then using scaled dot-product operations to weight each position’s contribution to the output.
  • Multi-head attention runs several attention mechanisms in parallel with different learned projections, allowing the model to capture diverse relational patterns simultaneously across different representation subspaces.
  • PyTorch’s built-in torch.nn.MultiheadAttention is optimized for production use, but implementing attention from scratch shows how simple the core of a transformer is and makes attention-related issues easier to debug.

Introduction to Self-Attention

Self-attention is the core mechanism that powers transformers, enabling models like BERT, GPT, and Vision Transformers to understand relationships between elements in a sequence. Unlike recurrent networks that process sequences sequentially, self-attention computes interactions between all positions simultaneously.

The fundamental idea: each token in a sequence attends to every other token (including itself) to determine how much each position should influence its representation. For the sentence “The cat sat on the mat,” self-attention helps the model learn that “cat” and “sat” are closely related, while “the” might attend strongly to the noun it modifies.

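Because all positions are processed in parallel, transformers train faster on modern hardware than RNNs or LSTMs. And because any two positions interact directly within a single layer, long-range dependencies are easier to capture. The trade-off is that compute and memory grow quadratically with sequence length.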

Mathematical Foundation

Self-attention uses three learned transformations of the input: Query (Q), Key (K), and Value (V). Think of it like a database lookup: queries search for relevant keys, and when matches are found, corresponding values are retrieved.
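To make the lookup analogy concrete, here is a minimal NumPy sketch of a "soft" lookup. The function name soft_lookup and the toy keys and values are hypothetical, chosen only for illustration. Instead of returning the single best match, it blends all values according to how well each key matches the query.

```python
import numpy as np

def soft_lookup(query, keys, values):
    """Return a similarity-weighted average of `values` (illustrative helper)."""
    scores = keys @ query                      # one similarity score per key
    scores = scores - scores.max()             # subtract the max for numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()
    return weights @ values, weights

# Three keys pointing in orthogonal directions, each paired with one value row
keys = np.eye(3) * 5.0                         # (3, 3)
values = np.array([[10.0, 0.0],
                   [0.0, 10.0],
                   [5.0, 5.0]])                # (3, 2)

# The query points in the same direction as the second key
query = np.array([0.0, 5.0, 0.0])
retrieved, weights = soft_lookup(query, keys, values)

print("weights:", weights.round(3))
print("retrieved:", retrieved.round(3))
```

Almost all of the weight lands on the second key, so the result is very close to the second value row. With less distinct keys, the result would be a blend of several rows.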

The attention formula is:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Let’s break this down with a minimal NumPy example:

import numpy as np

# 3 tokens, embedding dimension of 4
X = np.random.randn(3, 4)
print("Input shape:", X.shape)  # (3, 4)

# Initialize Q, K, V projection matrices
W_q = np.random.randn(4, 4)
W_k = np.random.randn(4, 4)
W_v = np.random.randn(4, 4)

# Project input to Q, K, V
Q = X @ W_q  # (3, 4)
K = X @ W_k  # (3, 4)
V = X @ W_v  # (3, 4)

# Compute attention scores
d_k = Q.shape[-1]
scores = Q @ K.T / np.sqrt(d_k)  # (3, 3)
print("Attention scores:\n", scores)

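# Apply softmax to get attention weights
# (subtracting the row max avoids overflow in exp without changing the result)
exp_scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
attention_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)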
print("Attention weights:\n", attention_weights)

# Compute weighted sum of values
output = attention_weights @ V  # (3, 4)
print("Output shape:", output.shape)

Row i of the attention weights matrix shows how much token i attends to each token in the sequence, itself included. The softmax makes each row sum to 1, so every output row is a weighted average of the value vectors.
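Why divide by √d_k in the formula above? The sketch below is an empirical check, not a proof. It assumes query and key entries are independent with zero mean and unit variance (the random data is made up for illustration). Under that assumption the raw dot product has a standard deviation of about √d_k, and the scaled one stays near 1 regardless of dimension.

```python
import numpy as np

rng = np.random.default_rng(0)
stats = {}

for d_k in (4, 64, 1024):
    # 2000 random query/key pairs with unit-variance entries
    q = rng.standard_normal((2000, d_k))
    k = rng.standard_normal((2000, d_k))

    raw = (q * k).sum(axis=-1)        # unscaled dot products
    scaled = raw / np.sqrt(d_k)       # scaled as in the attention formula

    stats[d_k] = (raw.std(), scaled.std())
    print(f"d_k={d_k:5d}  std(raw)={raw.std():7.2f}  std(scaled)={scaled.std():.2f}")
```

Without the scaling, larger embedding dimensions would feed ever larger scores into the softmax, pushing it toward a near one-hot output.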

Basic Self-Attention Implementation

Here’s a clean PyTorch implementation:

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)
        
    def forward(self, x):
        # x shape: (batch_size, seq_len, embed_dim)
        B, T, C = x.shape
        
        # Project to Q, K, V
        Q = self.W_q(x)  # (B, T, C)
        K = self.W_k(x)  # (B, T, C)
        V = self.W_v(x)  # (B, T, C)
        
        # Compute attention scores
        scores = Q @ K.transpose(-2, -1)  # (B, T, T)
        scores = scores / (C ** 0.5)  # scale
        
        # Apply softmax
        attn_weights = torch.softmax(scores, dim=-1)  # (B, T, T)
        
        # Weighted sum of values
        output = attn_weights @ V  # (B, T, C)
        
        return output, attn_weights

# Test it
batch_size, seq_len, embed_dim = 2, 5, 64
x = torch.randn(batch_size, seq_len, embed_dim)

attn = SelfAttention(embed_dim)
output, weights = attn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")

Scaled Dot-Product Attention

The scaling factor 1/√d_k prevents the dot products from growing too large, which would push the softmax into regions with extremely small gradients. Let’s add masking support for padding and causal attention:

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (B, T, d_k)
        d_k = Q.size(-1)
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
        
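        # Apply mask: positions where mask == 0 are blocked (padding or future tokens).
        # The mask must broadcast against scores, e.g. (T, T) or (B, T, T).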
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax and dropout
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention to values
        output = torch.matmul(attn_weights, V)
        
        return output, attn_weights

# Create a causal mask (for autoregressive models)
seq_len = 5
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
print("Causal mask:\n", causal_mask.int())

# Test with mask
Q = K = V = torch.randn(1, seq_len, 64)
attn = ScaledDotProductAttention().eval()  # eval mode turns dropout off, so each row sums to 1
output, weights = attn(Q, K, V, mask=causal_mask)
print("Masked attention weights:\n", weights[0].detach().numpy().round(3))
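The example above covers the causal case. For padding, a common approach is to build the mask from per-sequence lengths. The sketch below assumes that approach. The helper make_padding_mask and the lengths are hypothetical, and the mask follows the same convention as the code above: True (nonzero) means "may attend".

```python
import torch

def make_padding_mask(lengths, max_len):
    """Boolean mask of shape (B, 1, max_len): True for real tokens, False for padding."""
    positions = torch.arange(max_len).unsqueeze(0)             # (1, max_len)
    return (positions < lengths.unsqueeze(1)).unsqueeze(1)     # (B, 1, max_len)

seq_len = 5
lengths = torch.tensor([5, 3])        # second sequence has 2 padding tokens

pad_mask = make_padding_mask(lengths, seq_len)                        # (2, 1, 5)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))   # (5, 5)
combined = pad_mask & causal                                          # broadcasts to (2, 5, 5)

Q = K = V = torch.randn(2, seq_len, 16)
scores = Q @ K.transpose(-2, -1) / (Q.size(-1) ** 0.5)
scores = scores.masked_fill(~combined, float("-inf"))
weights = torch.softmax(scores, dim=-1)

print("Weights for the padded sequence:\n", weights[1].numpy().round(3))
```

Two caveats. Rows belonging to padded query positions still produce outputs, which should be ignored downstream. And a row whose keys are all masked would become NaN after the softmax; that cannot happen here because the causal mask always leaves the first key visible.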

Multi-Head Attention

Multi-head attention runs multiple attention operations in parallel, each with different learned projections. This allows the model to attend to different aspects of the input simultaneously:

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Single projection for all heads (more efficient)
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        B, T, C = x.shape
        
        # Project and split into multiple heads
        Q = self.W_q(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, num_heads, T, head_dim)
        K = self.W_k(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention to values
        attn_output = attn_weights @ V  # (B, num_heads, T, head_dim)
        
        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        
        # Final linear projection
        output = self.W_o(attn_output)
        
        return output, attn_weights

# Test multi-head attention
mha = MultiHeadAttention(embed_dim=512, num_heads=8)
x = torch.randn(2, 10, 512)
output, weights = mha(x)

print(f"Output shape: {output.shape}")
print(f"Attention weights shape (per head): {weights.shape}")
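The view/transpose calls that split and merge heads are easy to get wrong, so here is a small shape check on a toy tensor (the sizes are arbitrary). It illustrates that head h is simply the channel slice [h * head_dim, (h + 1) * head_dim), and that merging needs .contiguous() once the per-head tensor lives in its own (B, num_heads, T, head_dim) memory layout.

```python
import torch

B, T, C, num_heads = 2, 4, 8, 2
head_dim = C // num_heads
x = torch.arange(B * T * C, dtype=torch.float32).view(B, T, C)

# Split into heads: (B, T, C) -> (B, num_heads, T, head_dim).
# .contiguous() makes a copy laid out like the fresh tensor a matmul would return.
heads = x.view(B, T, num_heads, head_dim).transpose(1, 2).contiguous()

# Head 1 holds channels head_dim .. 2*head_dim - 1 of the original tensor
same_slice = torch.equal(heads[:, 1], x[:, :, head_dim:2 * head_dim])

# Merging without .contiguous() fails because the transposed tensor is not
# laid out so that the head and head_dim axes can be fused by view()
try:
    heads.transpose(1, 2).view(B, T, C)
    view_failed = False
except RuntimeError:
    view_failed = True

# Merge heads back: (B, num_heads, T, head_dim) -> (B, T, C)
merged = heads.transpose(1, 2).contiguous().view(B, T, C)
round_trip = torch.equal(merged, x)

print(heads.shape, same_slice, view_failed, round_trip)
```

Using .reshape(B, T, C) instead of .contiguous().view(B, T, C) would also work, since reshape copies when it has to.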

Practical Integration

Here’s an encoder block built on the multi-head attention above, used inside a small sequence classification model:

class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        # Self-attention with residual
        attn_out, _ = self.attention(x)
        x = self.norm1(x + attn_out)
        
        # Feedforward with residual
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        
        return x

class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional embeddings; inputs longer than 512 tokens are not supported
        self.pos_encoding = nn.Parameter(torch.randn(1, 512, embed_dim))
        
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, embed_dim * 4)
            for _ in range(num_layers)
        ])
        
        self.classifier = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        B, T = x.shape
        
        x = self.embedding(x) + self.pos_encoding[:, :T, :]
        
        for block in self.blocks:
            x = block(x)
        
        # Global average pooling
        x = x.mean(dim=1)
        
        return self.classifier(x)

# Quick forward-pass check (no training happens here)
model = SimpleTransformer(vocab_size=1000, embed_dim=128, num_heads=4, num_layers=2, num_classes=2)
x = torch.randint(0, 1000, (8, 50))  # batch of sequences
logits = model(x)
print(f"Logits shape: {logits.shape}")
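The snippet above only runs a forward pass. As a rough sketch of what a training loop could look like, the example below overfits a single random batch. To stay self-contained it uses TinyClassifier, a hypothetical stand-in built from PyTorch's nn.TransformerEncoderLayer rather than the classes defined above. The optimizer, learning rate, and random labels are likewise assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the classifier above, using PyTorch's own encoder layer."""
    def __init__(self, vocab_size=1000, embed_dim=32, num_heads=4, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=4 * embed_dim,
            dropout=0.0, batch_first=True,
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, tokens):
        h = self.encoder(self.embedding(tokens))   # (B, T, embed_dim)
        return self.classifier(h.mean(dim=1))      # (B, num_classes)

model = TinyClassifier()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
loss_fn = nn.CrossEntropyLoss()

# Random tokens and labels: enough to exercise the loop, not a real dataset
tokens = torch.randint(0, 1000, (8, 50))
labels = torch.randint(0, 2, (8,))

losses = []
model.train()
for step in range(20):
    optimizer.zero_grad()
    logits = model(tokens)
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Swapping TinyClassifier for the SimpleTransformer defined above should leave the loop itself unchanged, since both map a (batch, seq_len) tensor of token ids to class logits.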

Optimization and Best Practices

For production code, use PyTorch’s optimized implementation:

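import time

# PyTorch's built-in multi-head attention vs. the custom module from above
builtin_mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
custom_mha = MultiHeadAttention(512, 8)

# Eval mode turns dropout off in both modules so the comparison is like-for-like
builtin_mha.eval()
custom_mha.eval()

x = torch.randn(32, 100, 512)

def benchmark(fn, warmup=5, iters=100):
    # Warm up first, then time forward passes only (no autograd bookkeeping).
    # This measures CPU time; on GPU, call torch.cuda.synchronize() before reading the clock.
    with torch.no_grad():
        for _ in range(warmup):
            fn()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
    return time.perf_counter() - start

# Module construction happens outside the timed region
custom_time = benchmark(lambda: custom_mha(x))
builtin_time = benchmark(lambda: builtin_mha(x, x, x))

print(f"Custom: {custom_time:.3f}s")
print(f"Built-in: {builtin_time:.3f}s")
print(f"Speedup: {custom_time/builtin_time:.2f}x")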

Key optimizations to consider:

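  1. Use torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+), which dispatches to FlashAttention or memory-efficient kernels when the inputs and hardware support them
  2. Wrap the model in torch.compile() (PyTorch 2.0+) for JIT optimization
  3. Use mixed precision training with torch.autocast (the torch.cuda.amp API in older releases) for faster computation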
  4. Implement gradient checkpointing for very deep models to trade compute for memory
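As a sketch of the first item, the snippet below compares a hand-written causal attention with torch.nn.functional.scaled_dot_product_attention. It assumes PyTorch 2.0 or newer, and the tensor sizes are arbitrary. Both paths should agree up to floating-point error.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 6, 16                      # batch, heads, sequence length, head_dim
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Manual causal attention, as in the from-scratch modules
scores = q @ k.transpose(-2, -1) / (D ** 0.5)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ v

# Fused version: scaling, causal masking, softmax, and the weighted sum in one call
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

max_diff = (manual - fused).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

Note that the fused function returns only the output, not the attention weights, so the manual path remains useful when you want to inspect attention patterns.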

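On a GPU, the built-in implementation can be 2-3x faster thanks to fused CUDA kernels and optimized memory layouts, though the gap depends on hardware, sequence length, and PyTorch version. On CPU, as in the benchmark above, the difference may be much smaller, so measure on your own workload. Custom implementations remain invaluable for research, debugging attention patterns, and implementing novel attention variants like sparse attention or linear attention mechanisms.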
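One thing to watch when switching to the built-in module: for boolean masks, nn.MultiheadAttention uses the opposite convention from the from-scratch code in this article. There, True (nonzero) meant "may attend"; in the built-in module, True in attn_mask means "blocked". The sketch below, with arbitrary sizes, builds a causal mask in the built-in convention and checks that future positions receive zero weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C = 2, 5, 32
x = torch.randn(B, T, C)

mha = nn.MultiheadAttention(embed_dim=C, num_heads=4, batch_first=True).eval()

# True = blocked: everything strictly above the diagonal (future tokens)
blocked = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

with torch.no_grad():
    # By default the returned weights are averaged over heads: (B, T, T)
    out, weights = mha(x, x, x, attn_mask=blocked)

print("Output shape:", out.shape)
print("Averaged attention weights:\n", weights[0].numpy().round(3))
```

If you already have a mask in this article's convention, such as the torch.tril(...).bool() causal mask, invert it with ~mask before passing it as attn_mask.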
