How to Implement Attention Mechanism in PyTorch

Key Insights

  • Attention mechanisms let models dynamically focus on relevant parts of input sequences by computing weighted combinations based on query-key similarity scores
  • The scaled dot-product attention formula (softmax(QK^T/√d_k)V) is the foundation for modern transformers and can be implemented in just a few lines of PyTorch
  • Multi-head attention runs multiple attention operations in parallel to capture different representation subspaces, significantly improving model expressiveness over single-head variants

Introduction to Attention Mechanisms

Attention mechanisms revolutionized deep learning by solving a fundamental problem: how do we let models focus on the most relevant parts of their input? Before attention, sequence models like RNNs struggled with long-range dependencies because they compressed entire sequences into fixed-size vectors. Attention changed this by allowing models to dynamically compute weighted combinations of input elements.

The core intuition is simple. When translating “The cat sat on the mat” to French, the model should focus heavily on “cat” when generating “chat.” Attention provides exactly this capability—a learned mechanism to assign importance scores to different input positions.

Today, attention is everywhere: transformers, vision models, speech recognition, and even reinforcement learning. Understanding how to implement it from scratch gives you the foundation to work with modern architectures effectively.

Mathematical Foundation

Attention boils down to three components: queries (Q), keys (K), and values (V). Think of it like a database lookup. Your query searches through keys to find matches, then retrieves corresponding values. The attention score determines how much each value contributes to the output.

The scaled dot-product attention formula is:

Attention(Q, K, V) = softmax(QK^T / √d_k)V

Here, d_k is the dimension of the key vectors. We divide by √d_k to prevent the dot products from growing too large, which would push softmax into regions with tiny gradients.
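You can see the saturation problem directly with a quick (illustrative) experiment: dot products of random d_k-dimensional vectors have variance roughly d_k, so for large d_k the unscaled softmax collapses almost all its mass onto one position.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 512

# Dot products of random d_k-dimensional vectors have variance ~d_k,
# so raw scores grow large as d_k increases.
q = torch.randn(8, d_k)
k = torch.randn(8, d_k)
raw = q @ k.T  # (8, 8) score matrix

unscaled = F.softmax(raw, dim=-1)
scaled = F.softmax(raw / d_k ** 0.5, dim=-1)

# Without scaling, softmax saturates: nearly all mass lands on one position,
# which starves the other positions of gradient signal.
print(f"max weight, unscaled: {unscaled.max().item():.4f}")
print(f"max weight, scaled:   {scaled.max().item():.4f}")
```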

Let’s implement this from scratch:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Args:
        query: Tensor of shape (batch_size, q_len, d_k)
        key: Tensor of shape (batch_size, kv_len, d_k)
        value: Tensor of shape (batch_size, kv_len, d_v)
        mask: Optional tensor broadcastable to (batch_size, q_len, kv_len);
            positions where mask == 0 are blocked
    
    Returns:
        output: Weighted values (batch_size, q_len, d_v)
        attention_weights: Attention scores (batch_size, q_len, kv_len)
    """
    d_k = query.size(-1)
    
    # Compute attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5
    
    # Apply mask if provided (useful for causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Compute weighted sum of values
    output = torch.matmul(attention_weights, value)
    
    return output, attention_weights

# Test it
batch_size, seq_len, d_model = 2, 4, 8
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # (2, 4, 8)
print(f"Attention weights shape: {weights.shape}")  # (2, 4, 4)
print(f"Weights sum to 1: {weights.sum(dim=-1)}")  # Each row sums to 1
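The mask argument is what enables causal (autoregressive) attention, where each position may attend only to itself and earlier positions. Here's a self-contained sketch of the same computation with a lower-triangular mask:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, d_k = 2, 4, 8
Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

# Lower-triangular mask: 1 where attention is allowed, 0 for future positions
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
scores = scores.masked_fill(causal_mask == 0, -1e9)
weights = F.softmax(scores, dim=-1)

# Upper triangle is (numerically) zero; the first row attends only to itself
print(weights[0].round(decimals=2))
```

The (seq_len, seq_len) mask broadcasts across the batch dimension, so the same causal pattern applies to every sequence.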

Building a Simple Attention Layer

Now let’s wrap this in a proper PyTorch module with learnable parameters. An attention layer needs linear transformations to project inputs into Q, K, and V spaces.

import torch.nn as nn

class AttentionLayer(nn.Module):
    def __init__(self, d_model):
        """
        Args:
            d_model: Dimension of input embeddings
        """
        super(AttentionLayer, self).__init__()
        self.d_model = d_model
        
        # Linear layers to project to Q, K, V
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        
        self.output_proj = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask=None):
        # Project inputs to Q, K, V
        Q = self.query_proj(query)
        K = self.key_proj(key)
        V = self.value_proj(value)
        
        # Apply attention
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # Final linear projection
        output = self.output_proj(attn_output)
        
        return output, attn_weights

# Visualize attention weights
attention_layer = AttentionLayer(d_model=64)
x = torch.randn(1, 5, 64)  # Single sequence of length 5

output, weights = attention_layer(x, x, x)
print("Attention weight matrix:")
print(weights[0].detach().numpy().round(2))

The attention weight matrix shows how much each position attends to every other position. Rows sum to 1.0, and higher values indicate stronger connections.

Self-Attention Implementation

Self-attention is a special case where queries, keys, and values all come from the same input sequence. This lets the model capture relationships between different positions in the same sequence—crucial for language understanding.

class SelfAttention(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super(SelfAttention, self).__init__()
        self.attention = AttentionLayer(d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        
    def forward(self, x, mask=None):
        # Self-attention: Q, K, V all come from x
        attn_output, attn_weights = self.attention(x, x, x, mask)
        
        # Residual connection and layer norm
        x = self.layer_norm(x + self.dropout(attn_output))
        
        return x, attn_weights

# Apply to sequence data
vocab_size, seq_len, d_model = 1000, 10, 128
embedding = nn.Embedding(vocab_size, d_model)

# Simulate tokenized text
tokens = torch.randint(0, vocab_size, (2, seq_len))
embedded = embedding(tokens)

self_attn = SelfAttention(d_model)
output, weights = self_attn(embedded)
print(f"Self-attention output shape: {output.shape}")  # (2, 10, 128)

Notice the residual connection and layer normalization—these are standard practices that stabilize training and improve gradient flow.

Multi-Head Attention

Single attention heads have limited capacity. Multi-head attention runs multiple attention operations in parallel, each learning different aspects of the relationships. Think of it as having multiple “attention experts” that specialize in different patterns.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def split_heads(self, x):
        """Split last dimension into (num_heads, d_k)"""
        batch_size, seq_len, d_model = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        """Inverse of split_heads"""
        batch_size, num_heads, seq_len, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Linear projections
        Q = self.split_heads(self.query_proj(query))
        K = self.split_heads(self.key_proj(key))
        V = self.split_heads(self.value_proj(value))
        
        # Scaled dot-product attention for each head
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        attn_output = torch.matmul(attn_weights, V)
        
        # Combine heads and apply final linear projection
        output = self.combine_heads(attn_output)
        output = self.output_proj(output)
        
        return output, attn_weights

# Compare single vs multi-head
single_head = MultiHeadAttention(d_model=128, num_heads=1)
multi_head = MultiHeadAttention(d_model=128, num_heads=8)

x = torch.randn(2, 10, 128)
out_single, _ = single_head(x, x, x)
out_multi, _ = multi_head(x, x, x)

print(f"Single head output: {out_single.shape}")
print(f"Multi head output: {out_multi.shape}")

Multi-head attention with 8 heads means 8 parallel attention operations, each working in a 16-dimensional subspace (128/8). This parallelism captures richer patterns than a single 128-dimensional attention.
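One detail that's easy to miss: adding heads doesn't add parameters. The Q, K, V, and output projections stay d_model × d_model regardless of head count; the heads merely partition them. A quick check with PyTorch's built-in module (the same holds for the class above):

```python
import torch.nn as nn

def param_count(module):
    return sum(p.numel() for p in module.parameters())

single = nn.MultiheadAttention(embed_dim=128, num_heads=1)
multi = nn.MultiheadAttention(embed_dim=128, num_heads=8)

# Both hold the same-sized Q/K/V/output projections; heads only reshape them
print(param_count(single), param_count(multi))
```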

Practical Application: Sequence Classification

Let’s build a complete sentiment classifier using our attention layers. This demonstrates how attention fits into real models.

class AttentionClassifier(nn.Module):
    def __init__(self, vocab_size, d_model=128, num_heads=4, num_classes=2):
        super(AttentionClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        # x: (batch_size, seq_len)
        x = self.embedding(x)  # (batch_size, seq_len, d_model)
        
        # Self-attention
        attn_out, _ = self.mha(x, x, x)
        x = self.layer_norm(x + self.dropout(attn_out))
        
        # Global average pooling
        x = x.mean(dim=1)  # (batch_size, d_model)
        
        # Classification
        logits = self.fc(x)
        return logits

# Training example
model = AttentionClassifier(vocab_size=5000, d_model=128, num_heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Dummy training data
train_texts = torch.randint(0, 5000, (32, 50))  # 32 sequences of length 50
train_labels = torch.randint(0, 2, (32,))  # Binary labels

# Training step
model.train()
optimizer.zero_grad()
logits = model(train_texts)
loss = criterion(logits, train_labels)
loss.backward()
optimizer.step()

print(f"Training loss: {loss.item():.4f}")

This classifier uses attention to capture relationships between words, pools the sequence representations, and classifies sentiment. For production use, you’d add positional encodings, multiple attention layers, and train on real data like IMDB reviews.
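Positional encodings matter here because both attention and mean pooling are permutation-invariant: without them, the classifier sees a bag of words. A minimal sketch of the fixed sinusoidal encoding from the original transformer paper (the function name is ours), which you would add to the embeddings before attention:

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len, d_model):
    """Fixed sin/cos positional encodings, one row per position."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    # Frequencies decay geometrically from 1 down to 1/10000 across dimensions
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe

pe = sinusoidal_positional_encoding(seq_len=50, d_model=128)
print(pe.shape)  # torch.Size([50, 128])
# Usage sketch: embedded = embedding(tokens) + pe.unsqueeze(0)
```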

Conclusion and Next Steps

You now understand how to implement attention mechanisms from first principles in PyTorch. The scaled dot-product attention formula is surprisingly simple, yet it powers state-of-the-art models across domains.

Key takeaways: attention computes weighted combinations based on learned similarity scores, self-attention captures intra-sequence relationships, and multi-head attention increases model capacity through parallel attention operations.

For production code, use PyTorch’s built-in nn.MultiheadAttention—it’s optimized and battle-tested. But understanding the implementation details helps you debug issues, customize architectures, and stay current as the field evolves.
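For reference, a self-attention call with the built-in module looks roughly like this. Note batch_first=True: the module's default layout is (seq_len, batch, d_model), which trips up many first-time users.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=128, num_heads=8, dropout=0.1, batch_first=True)

x = torch.randn(2, 10, 128)  # (batch, seq_len, d_model) thanks to batch_first=True
attn_output, attn_weights = mha(x, x, x)

print(attn_output.shape)   # torch.Size([2, 10, 128])
print(attn_weights.shape)  # torch.Size([2, 10, 10]) — averaged over heads by default
```

Pass average_attn_weights=False to the forward call if you need the per-head weight tensor instead of the head-averaged one.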

Next steps: explore causal masking for autoregressive models, cross-attention for encoder-decoder architectures, relative positional encodings, and attention optimization techniques like flash attention. The attention mechanism is your foundation for working with transformers, and you’re now equipped to build and modify them confidently.
