How to Implement a Transformer in PyTorch
Key Insights
- The Transformer architecture replaces recurrence with self-attention, which processes all positions of a sequence in parallel and connects any two positions in a single step, making long-range dependencies easier to model than in recurrent networks such as LSTMs.
- Multi-head attention allows the model to jointly attend to information from different representation subspaces, while positional encodings inject sequence order information that attention mechanisms lack.
- Implementing a Transformer from scratch requires understanding five core components: multi-head attention, feed-forward networks, positional encoding, layer normalization with residual connections, and proper masking strategies.
Architecture Overview
The Transformer architecture, introduced in “Attention is All You Need,” revolutionized sequence modeling by eliminating recurrent connections entirely. Instead of processing sequences step-by-step, Transformers use self-attention to compute representations for all positions simultaneously.
The architecture consists of an encoder stack and decoder stack, each containing multiple identical layers. Encoder layers process the input sequence, while decoder layers generate the output sequence autoregressively. Both rely on the same fundamental building blocks: multi-head self-attention, position-wise feed-forward networks, layer normalization, and residual connections.
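As a reference point for the tensor shapes involved, PyTorch ships a built-in `nn.Transformer` that implements a closely related encoder-decoder stack. It does not include embeddings, positional encoding, or the final vocabulary projection, and it adds a final LayerNorm to each stack. The sketch below is only a shape check with deliberately small, arbitrary sizes. It is not part of the from-scratch implementation that follows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small illustrative sizes (the paper's base model uses d_model=512, 8 heads, 6 layers)
d_model, num_heads, num_layers = 32, 4, 2
batch_size, src_len, tgt_len = 3, 7, 5

reference = nn.Transformer(
    d_model=d_model,
    nhead=num_heads,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    dim_feedforward=4 * d_model,
    dropout=0.1,
    batch_first=True,  # (batch, seq, feature), the layout used throughout this article
)

# nn.Transformer expects already-embedded inputs
src = torch.randn(batch_size, src_len, d_model)
tgt = torch.randn(batch_size, tgt_len, d_model)

# Float mask with -inf above the diagonal: decoder position i only sees positions <= i
tgt_mask = reference.generate_square_subsequent_mask(tgt_len)

out = reference(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([3, 5, 32]): one d_model vector per target position
```

The encoder consumes all 7 source positions, but the output has the decoder's length of 5. The hand-written model below behaves the same way.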
Let’s build each component systematically, starting with the attention mechanism itself.
Multi-Head Self-Attention Mechanism
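Self-attention computes a weighted sum of value vectors, where the weights are determined by the compatibility (dot product) between query and key vectors. The dot products are divided by sqrt(d_k), where d_k is the key dimension. Without this scaling they would grow with the dimension and push the softmax into regions with very small gradients. The scaled dot-product attention formula is: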
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
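To make the formula concrete, here is a tiny worked example in plain Python with no PyTorch. It uses two query positions, three key/value positions, and d_k = 2. The numbers are arbitrary. The example shows two things: each row of weights is a softmax over the keys, and each output row is the matching weighted average of the value rows.

```python
import math

def softmax(row):
    # Subtracting the max avoids overflow and leaves the result unchanged
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    d_k = len(K[0])
    # scores[i][j] = (q_i . k_j) / sqrt(d_k)
    scores = [
        [sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in K]
        for q_row in Q
    ]
    # Softmax over keys, one row per query
    weights = [softmax(row) for row in scores]
    # output[i] = sum_j weights[i][j] * v_j
    output = [
        [sum(w * v_row[c] for w, v_row in zip(w_row, V)) for c in range(len(V[0]))]
        for w_row in weights
    ]
    return output, weights

Q = [[1.0, 0.0], [0.0, 1.0]]              # 2 queries, d_k = 2
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 keys
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 values

output, weights = attention(Q, K, V)
for w_row, o_row in zip(weights, output):
    print([round(w, 3) for w in w_row], [round(o, 3) for o in o_row])
```

The first query, [1, 0], has the same dot product (1) with keys 0 and 2, so it gives them equal weight. Key 1 (dot product 0) gets less.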
Multi-head attention runs this process in parallel across multiple representation subspaces, then concatenates and projects the results.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V shape: (batch_size, num_heads, seq_len, d_k)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # Linear projections and reshape for multi-head attention
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads and apply final linear projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        output = self.W_o(attn_output)
        return output, attn_weights
```
The key insight here is the reshaping operation that splits the d_model dimension into num_heads separate attention heads, each operating on d_k dimensions. This allows the model to attend to different aspects of the input simultaneously.
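A quick way to check that the head split and merge are exact inverses is to apply them to a random tensor without any projections. The sizes below are arbitrary, and the `view`/`transpose` calls mirror those in `forward` above.

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 5, 12, 3
d_k = d_model // num_heads

x = torch.randn(batch_size, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, heads, seq, d_k)
heads = x.view(batch_size, -1, num_heads, d_k).transpose(1, 2)

# Merge: (batch, heads, seq, d_k) -> (batch, seq, d_model)
# .contiguous() is needed because transpose returns a non-contiguous view
merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, d_model)

print(heads.shape)             # torch.Size([2, 3, 5, 4])
print(torch.equal(merged, x))  # True

# Head h works on feature columns h*d_k .. (h+1)*d_k - 1 of every position
print(torch.equal(heads[:, 1], x[:, :, d_k:2 * d_k]))  # True
```

Each head therefore sees its own contiguous slice of the projected features.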
Position-wise Feed-Forward Networks
After attention, each position passes through an identical feed-forward network consisting of two linear transformations with a ReLU activation in between. This adds non-linearity and increases model capacity.
```python
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        return self.fc2(self.dropout(F.relu(self.fc1(x))))
```
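The original paper uses d_ff = 2048 with d_model = 512. This four-times expansion has become a common default. The same network is applied to each position independently, so the expansion adds per-token capacity without mixing information across positions; that mixing is the job of attention.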
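For a sense of scale, the parameter counts of the two sub-layers can be worked out directly from the `nn.Linear` shapes, since each layer has a weight matrix plus a bias vector. With the base sizes, the feed-forward network holds roughly twice as many parameters as the attention block. The attention count does not depend on `num_heads`.

```python
def linear_params(in_features, out_features):
    # nn.Linear stores an (out, in) weight matrix and an (out,) bias
    return in_features * out_features + out_features

d_model, d_ff = 512, 2048

# MultiHeadAttention: W_q, W_k, W_v, W_o, each d_model -> d_model
attn_params = 4 * linear_params(d_model, d_model)

# PositionwiseFeedForward: fc1 (d_model -> d_ff) and fc2 (d_ff -> d_model)
ffn_params = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)

print(attn_params)  # 1050624
print(ffn_params)   # 2099712
```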
Positional Encoding
Since attention has no inherent notion of sequence order, we must inject positional information. The original Transformer uses sinusoidal functions of different frequencies:
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Create positional encoding matrix (assumes d_model is even)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # Add batch dimension
        # A buffer follows the module across devices but is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model); requires seq_len <= max_len
        return x + self.pe[:, :x.size(1), :]
```
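Written out, the class computes the encoding from the original paper. Here pos is the position and i indexes pairs of dimensions. `div_term` holds the factors 10000^{-2i/d_model}, computed through `exp` and `log` rather than a direct power.

```latex
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)

\text{div\_term}_i = \exp\!\left(2i \cdot \frac{-\ln 10000}{d_{\text{model}}}\right) = 10000^{-2i/d_{\text{model}}}
```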
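Sinusoidal encodings add no learned parameters. For any fixed offset k, the encoding of position pos + k is a linear function of the encoding of position pos. The original paper chose them because this property may let the model extrapolate to sequence lengths longer than those seen during training. In this implementation, sequences must still be no longer than max_len.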
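The linear-function property can be checked numerically. Take each frequency w and the pair (sin(w·pos), cos(w·pos)). Shifting the position by k gives the same result as applying a fixed 2×2 matrix that depends only on w and k. The plain-Python sketch below uses an arbitrary d_model, position, and offset. It demonstrates the linear relationship only; it does not show that a trained model actually extrapolates.

```python
import math

def sinusoidal_pe(pos, d_model):
    # Same formula as PositionalEncoding: even dims get sin, odd dims get cos
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        w = math.exp(i * (-math.log(10000.0) / d_model))  # 1 / 10000^(i/d_model)
        pe[i] = math.sin(pos * w)
        pe[i + 1] = math.cos(pos * w)
    return pe

def shift_by(pe, k, d_model):
    # Per frequency, apply [[cos(wk), sin(wk)], [-sin(wk), cos(wk)]].
    # The matrix depends only on the offset k, not on the position encoded in `pe`.
    out = [0.0] * d_model
    for i in range(0, d_model, 2):
        w = math.exp(i * (-math.log(10000.0) / d_model))
        s, c = pe[i], pe[i + 1]
        out[i] = math.cos(w * k) * s + math.sin(w * k) * c
        out[i + 1] = -math.sin(w * k) * s + math.cos(w * k) * c
    return out

d_model, pos, k = 16, 37, 5
predicted = shift_by(sinusoidal_pe(pos, d_model), k, d_model)
actual = sinusoidal_pe(pos + k, d_model)
max_error = max(abs(p - a) for p, a in zip(predicted, actual))
print(max_error < 1e-9)  # True
```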
Building Encoder and Decoder Blocks
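Now we assemble the components into complete encoder and decoder layers. Each sub-layer (attention or feed-forward) is wrapped in a residual connection followed by layer normalization, LayerNorm(x + Dropout(Sublayer(x))). This is the post-norm arrangement of the original paper.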
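The residual-plus-normalization wrapper that both layer types repeat can be isolated in a few lines. In this minimal sketch, an arbitrary `nn.Linear` stands in for the attention or feed-forward sub-layer. At initialization `nn.LayerNorm` has unit scale and zero shift. As a result, every position's output vector has mean close to 0 and variance close to 1, whatever the scale of the input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16

sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or feed-forward
norm = nn.LayerNorm(d_model)
dropout = nn.Dropout(0.1)
dropout.eval()  # disable dropout so the check below is deterministic

x = torch.randn(2, 5, d_model) * 3 + 1  # arbitrary scale and offset

# Post-norm residual block: normalize after adding the sub-layer output back to x
y = norm(x + dropout(sublayer(x)))

# Statistics are taken over the feature dimension, separately for every position
mean = y.mean(dim=-1)
var = ((y - y.mean(dim=-1, keepdim=True)) ** 2).mean(dim=-1)
print(mean.abs().max().item(), (var - 1).abs().max().item())  # both close to 0
```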
```python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection and layer norm
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
```
```python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Masked self-attention
        attn_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Cross-attention to encoder output
        attn_output, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))

        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x
```
The decoder has an additional cross-attention layer that attends to the encoder’s output, allowing it to condition generation on the input sequence.
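In cross-attention the queries come from the decoder, while the keys and values come from the encoder output. The two sequences can therefore have different lengths, and the result always has the decoder's length. The sketch below demonstrates this with PyTorch's built-in `nn.MultiheadAttention`, used here only for a shape check with arbitrary sizes. The `MultiHeadAttention` class above follows the same convention, except that it returns per-head weights of shape (batch, heads, tgt_len, src_len) instead of averaging them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 32, 4
batch_size, src_len, tgt_len = 2, 9, 4

cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

decoder_states = torch.randn(batch_size, tgt_len, d_model)  # queries
encoder_output = torch.randn(batch_size, src_len, d_model)  # keys and values

out, weights = cross_attn(decoder_states, encoder_output, encoder_output)

print(out.shape)      # torch.Size([2, 4, 32]): one vector per decoder position
print(weights.shape)  # torch.Size([2, 4, 9]): weights over encoder positions, averaged across heads
```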
Complete Transformer Model
Let’s combine everything into the full Transformer architecture:
```python
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
                 num_heads=8, num_layers=6, d_ff=2048, max_len=5000, dropout=0.1):
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)

    def encode(self, src, src_mask=None):
        x = self.dropout(self.positional_encoding(
            self.encoder_embedding(src) * self.scale
        ))
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, enc_output, src_mask=None, tgt_mask=None):
        x = self.dropout(self.positional_encoding(
            self.decoder_embedding(tgt) * self.scale
        ))
        for layer in self.decoder_layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        return self.fc_out(x)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        enc_output = self.encode(src, src_mask)
        output = self.decode(tgt, enc_output, src_mask, tgt_mask)
        return output
```
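As a sanity check on model size, the parameters of the default configuration can be tallied by hand from the modules defined above. This sketch assumes the example vocabulary size of 10,000 used in the training setup below. Like the code above, it assumes no weight sharing between the embeddings and the output layer; the original paper does share them. For the code as written, the total should match `sum(p.numel() for p in model.parameters())`.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight + bias

def layernorm_params(d):
    return 2 * d  # elementwise scale and shift

d_model, d_ff, num_layers = 512, 2048, 6
src_vocab_size = tgt_vocab_size = 10000

mha = 4 * linear_params(d_model, d_model)                          # W_q, W_k, W_v, W_o
ffn = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)  # fc1, fc2

encoder_layer = mha + ffn + 2 * layernorm_params(d_model)          # self-attn, FFN, 2 norms
decoder_layer = 2 * mha + ffn + 3 * layernorm_params(d_model)      # self-attn, cross-attn, FFN, 3 norms

# The positional encoding is a registered buffer, so it contributes no parameters
embeddings = (src_vocab_size + tgt_vocab_size) * d_model
output_layer = linear_params(d_model, tgt_vocab_size)

total = num_layers * (encoder_layer + decoder_layer) + embeddings + output_layer
print(f"{total:,}")  # 59,508,496
```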
Training and Inference Example
Proper masking is crucial for training. We need padding masks to ignore padding tokens and causal masks to prevent the decoder from attending to future positions:
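```python
def create_padding_mask(seq, pad_idx=0):
    # True for real tokens, False for padding.
    # Shape (batch_size, 1, 1, seq_len) broadcasts over heads and query positions.
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)

def create_causal_mask(size):
    # Lower-triangular boolean mask: position i may attend to positions <= i.
    # A bool dtype is required so it can be combined with the padding mask via &.
    mask = torch.tril(torch.ones(size, size, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)
    return mask
```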
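To see how the two masks combine, here is a small example with two target sequences. The second sequence ends with two padding tokens. The example uses boolean masks throughout, because the `&` operator is not defined for floating-point tensors. Broadcasting a (batch, 1, 1, seq) padding mask against a (1, 1, seq, seq) causal mask yields one (seq, seq) mask per sequence. The token ids are arbitrary, with 0 as padding to match `pad_idx=0`.

```python
import torch

pad_idx = 0
tgt = torch.tensor([[5, 7, 9, 4],
                    [3, 8, 0, 0]])  # second sequence has two padding tokens
seq_len = tgt.size(1)

padding_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)                 # (2, 1, 1, 4)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # (4, 4)
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)                       # (1, 1, 4, 4)

combined = causal_mask & padding_mask                                     # (2, 1, 4, 4)
print(combined.shape)
print(combined[1, 0].int())
```

Row i lists the key positions that query i may attend to: never a future position, and never a padding token. Rows for padded query positions still see the real tokens, and their predictions are dropped by `ignore_index` in the loss.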
```python
# Example training setup
src_vocab_size = 10000
tgt_vocab_size = 10000
model = Transformer(src_vocab_size, tgt_vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is the padding index

# Training loop (assumes `num_epochs` and a `dataloader` yielding
# (src_batch, tgt_batch) LongTensors of shape (batch_size, seq_len) are defined)
model.train()
for epoch in range(num_epochs):
    for src_batch, tgt_batch in dataloader:
        # Teacher forcing: the decoder sees tokens 0..n-2 and predicts tokens 1..n-1
        tgt_input = tgt_batch[:, :-1]
        tgt_output = tgt_batch[:, 1:]

        # Create masks for the source and the shifted decoder input
        src_mask = create_padding_mask(src_batch)
        causal_mask = create_causal_mask(tgt_input.size(1)).to(tgt_input.device)
        tgt_mask = causal_mask & create_padding_mask(tgt_input)

        # Forward pass
        output = model(src_batch, tgt_input, src_mask, tgt_mask)

        # Calculate loss against the shifted target; padding targets are ignored
        loss = criterion(output.reshape(-1, tgt_vocab_size), tgt_output.reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
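The shift-by-one and `ignore_index` logic can be checked in isolation, with random logits standing in for the model output. The token ids below are hypothetical: 1 is the start token, 2 the end token, and 0 padding, matching the defaults used here. The check confirms that the loss averages only over non-padding targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 10
tgt_batch = torch.tensor([[1, 5, 7, 2, 0]])  # <start> 5 7 <end> <pad>

tgt_input = tgt_batch[:, :-1]  # [[1, 5, 7, 2]] fed to the decoder
tgt_output = tgt_batch[:, 1:]  # [[5, 7, 2, 0]] what each position should predict

# Random scores standing in for model output: (batch, tgt_len, vocab)
logits = torch.randn(1, tgt_input.size(1), vocab_size)

criterion = nn.CrossEntropyLoss(ignore_index=0)
loss = criterion(logits.reshape(-1, vocab_size), tgt_output.reshape(-1))

# The same value, computed only over the three non-padding targets
manual = F.cross_entropy(logits[0, :3], tgt_output[0, :3])
print(torch.allclose(loss, manual))  # True
```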
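```python
# Greedy decoding for inference (assumes a single source sequence, batch_size = 1)
@torch.no_grad()  # no gradients are needed during generation
def generate(model, src, max_len=50, start_token=1, end_token=2):
    model.eval()
    src_mask = create_padding_mask(src)
    enc_output = model.encode(src, src_mask)
    tgt = torch.tensor([[start_token]], dtype=torch.long, device=src.device)
    for _ in range(max_len):
        tgt_mask = create_causal_mask(tgt.size(1)).to(src.device)
        output = model.decode(tgt, enc_output, src_mask, tgt_mask)
        # Most likely next token at the last position, shape (1, 1)
        next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
        if next_token.item() == end_token:
            break
        tgt = torch.cat([tgt, next_token], dim=1)
    # Includes the start token; the end token is not appended
    return tgt
```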
This implementation provides a complete encoder-decoder Transformer. For production use, add learning rate scheduling (linear warmup followed by inverse-square-root decay, as in the original paper), label smoothing, and beam search decoding. The same building blocks carry over to translation, summarization, and other sequence-to-sequence tasks. Decoder-only language models keep just the decoder stack, dropping the encoder and the cross-attention sub-layer.
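The sketch below shows one way to add the learning-rate schedule and label smoothing. It implements the warmup-then-decay schedule from the original paper, lrate = d_model^-0.5 · min(step^-0.5, step · warmup^-1.5) with warmup = 4000, through `torch.optim.lr_scheduler.LambdaLR`. Label smoothing comes from the `label_smoothing` argument of `nn.CrossEntropyLoss`, available in PyTorch 1.10 and later. The single linear layer is only a placeholder for the Transformer. The base learning rate is set to 1.0, so the lambda's return value is the actual learning rate.

```python
import torch
import torch.nn as nn

def noam_lr(step, d_model=512, warmup=4000):
    # Linear warmup for `warmup` steps, then decay proportional to 1/sqrt(step)
    step = max(step, 1)  # LambdaLR starts counting at 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

model = nn.Linear(8, 8)  # placeholder for the Transformer
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lr)

# Label smoothing of 0.1, as in the paper; padding (index 0) is still ignored
criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

# In the real training loop, call scheduler.step() right after optimizer.step()
for _ in range(3):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])  # equals noam_lr(3)
```

The learning rate peaks at step `warmup` and decays afterwards, so the small `lr=0.0001` in the training setup above would be replaced by this schedule.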