How to Implement Seq2Seq Models in PyTorch
Key Insights
- Seq2seq models use an encoder-decoder architecture to transform variable-length input sequences into variable-length outputs, making them ideal for translation, summarization, and dialogue systems
- Attention mechanisms are essential for production seq2seq models—they solve the information bottleneck problem by letting decoders focus on relevant input positions rather than compressing everything into a fixed context vector
- Teacher forcing during training (feeding ground truth instead of predictions) accelerates convergence but requires careful scheduling to prevent exposure bias at inference time
Introduction to Sequence-to-Sequence Architecture
Sequence-to-sequence (seq2seq) models solve a fundamental problem in machine learning: mapping variable-length input sequences to variable-length output sequences. Unlike traditional neural networks that require fixed-size inputs and outputs, seq2seq architectures handle sequences of arbitrary length.
The core applications are everywhere: machine translation (English to French), text summarization (long article to summary), chatbots (user message to response), and even code generation. The architecture consists of two main components: an encoder that processes the input sequence into a fixed-size context representation, and a decoder that generates the output sequence from this context.
The basic seq2seq model has a critical limitation—compressing all input information into a single fixed-size vector creates an information bottleneck. This is where attention mechanisms become crucial, allowing the decoder to selectively focus on different parts of the input at each generation step.
Building the Encoder
The encoder’s job is straightforward: read the input sequence and produce a meaningful representation. We use recurrent layers (LSTM or GRU) to process sequences while maintaining temporal dependencies. An embedding layer first converts discrete tokens into continuous vectors.
Here’s a production-ready encoder implementation:
```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, num_layers=1, dropout=0.1):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_lengths):
        # src shape: (batch_size, seq_len)
        embedded = self.dropout(self.embedding(src))
        # embedded shape: (batch_size, seq_len, embedding_dim)

        # Pack padded sequences for efficiency
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, src_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        outputs, (hidden, cell) = self.lstm(packed)

        # Unpack sequences
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        # outputs shape: (batch_size, seq_len, hidden_dim)
        # hidden shape: (num_layers, batch_size, hidden_dim)
        return outputs, hidden, cell
```
The key details: we use pack_padded_sequence to avoid wasting computation on padding tokens, and we return both the full output sequence (needed for attention) and the final hidden states (the context vector).
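The packing behavior is easy to verify in isolation. A small sketch with a toy batch of two sequences (lengths 3 and 2, feature dimension 1):

```python
import torch
import torch.nn as nn

# A padded batch: two sequences (lengths 3 and 2), feature dim 1
padded = torch.tensor([[[1.], [2.], [3.]],
                       [[4.], [5.], [0.]]])   # trailing 0. is padding
lengths = torch.tensor([3, 2])

packed = nn.utils.rnn.pack_padded_sequence(
    padded, lengths, batch_first=True, enforce_sorted=False
)
# packed.data holds only the 5 real timesteps; the padding step is dropped
print(packed.data.shape)        # torch.Size([5, 1])
print(packed.batch_sizes)       # tensor([2, 2, 1])

# Unpacking restores the padded (batch, seq, feature) layout and the lengths
unpacked, out_lengths = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)
print(unpacked.shape)           # torch.Size([2, 3, 1])
```

`batch_sizes` shows why the RNN saves work: at timestep 2 only one sequence is still active, so only one hidden-state update is computed.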
Building the Decoder
The decoder generates output sequences one token at a time. At each step, it receives the previous token (or the ground truth during teacher forcing), the previous hidden state, and optionally an attention-weighted context from the encoder outputs.
```python
class Decoder(nn.Module):
    def __init__(self, output_vocab_size, embedding_dim, hidden_dim, num_layers=1, dropout=0.1):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_vocab_size = output_vocab_size
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        self.fc_out = nn.Linear(hidden_dim, output_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, cell):
        # input_token shape: (batch_size, 1)
        embedded = self.dropout(self.embedding(input_token))
        # embedded shape: (batch_size, 1, embedding_dim)
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        # output shape: (batch_size, 1, hidden_dim)
        prediction = self.fc_out(output.squeeze(1))
        # prediction shape: (batch_size, output_vocab_size)
        return prediction, hidden, cell
```
Teacher forcing is critical here. During training, instead of feeding the decoder’s own predictions as input for the next step, we feed the ground truth tokens. This dramatically speeds up training but can create exposure bias—the model never learns to recover from its own mistakes.
Implementing the Attention Mechanism
Attention mechanisms are non-negotiable for real-world seq2seq models. They allow the decoder to create a different context vector at each decoding step by computing weighted sums over encoder outputs.
Here’s a Bahdanau-style attention implementation:
```python
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_dim, hidden_dim)
        self.Ua = nn.Linear(hidden_dim, hidden_dim)
        self.Va = nn.Linear(hidden_dim, 1)

    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        # decoder_hidden shape: (batch_size, hidden_dim)
        # encoder_outputs shape: (batch_size, src_len, hidden_dim)

        # Score calculation
        decoder_hidden = decoder_hidden.unsqueeze(1)  # (batch_size, 1, hidden_dim)
        score = self.Va(torch.tanh(
            self.Wa(decoder_hidden) + self.Ua(encoder_outputs)
        ))  # (batch_size, src_len, 1)
        score = score.squeeze(2)  # (batch_size, src_len)

        # Apply mask for padded positions
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e10)
        attention_weights = torch.softmax(score, dim=1)  # (batch_size, src_len)

        # Context vector: weighted sum of encoder outputs
        context = torch.bmm(
            attention_weights.unsqueeze(1),
            encoder_outputs
        ).squeeze(1)  # (batch_size, hidden_dim)
        return context, attention_weights


class AttentionDecoder(nn.Module):
    def __init__(self, output_vocab_size, embedding_dim, hidden_dim, num_layers=1, dropout=0.1):
        super(AttentionDecoder, self).__init__()
        self.output_vocab_size = output_vocab_size  # needed by the Seq2Seq wrapper
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.attention = BahdanauAttention(hidden_dim)
        self.lstm = nn.LSTM(
            embedding_dim + hidden_dim,  # Concatenate embedding with context
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        self.fc_out = nn.Linear(hidden_dim, output_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, cell, encoder_outputs, src_mask=None):
        embedded = self.dropout(self.embedding(input_token))

        # Get attention context using top layer hidden state
        context, attention_weights = self.attention(
            hidden[-1], encoder_outputs, src_mask
        )

        # Concatenate embedding with context
        lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        prediction = self.fc_out(output.squeeze(1))
        return prediction, hidden, cell, attention_weights
```
Training the Complete Model
Training seq2seq models requires careful orchestration of the encoder, decoder, and loss computation. Use cross-entropy loss and implement teacher forcing with a scheduled ratio.
```python
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, src_lengths, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_vocab_size
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(src.device)
        encoder_outputs, hidden, cell = self.encoder(src, src_lengths)

        # First input is <SOS> token
        input_token = trg[:, 0].unsqueeze(1)
        for t in range(1, trg_len):
            # Unpacks four values, so this expects the AttentionDecoder interface
            prediction, hidden, cell, _ = self.decoder(
                input_token, hidden, cell, encoder_outputs
            )
            outputs[:, t] = prediction
            # Teacher forcing decision
            use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio
            top1 = prediction.argmax(1).unsqueeze(1)
            input_token = trg[:, t].unsqueeze(1) if use_teacher_forcing else top1
        return outputs
```
```python
def train_epoch(model, dataloader, optimizer, criterion, device, teacher_forcing_ratio=0.5):
    model.train()
    epoch_loss = 0
    for batch in dataloader:
        src, src_lengths, trg = batch
        src, trg = src.to(device), trg.to(device)
        optimizer.zero_grad()
        output = model(src, src_lengths, trg, teacher_forcing_ratio)

        # Reshape for loss calculation, skipping the <SOS> position
        output_dim = output.shape[-1]
        output = output[:, 1:].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()
        # Gradient clipping prevents exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)
```
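For the `criterion`, a common choice is cross-entropy with `ignore_index` set to the padding value, so padded target positions contribute no loss. A minimal sketch, assuming padding index 0 (the logits and targets here are toy values):

```python
import torch
import torch.nn as nn

# ignore_index skips padded target positions entirely
criterion = nn.CrossEntropyLoss(ignore_index=0)

logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.1, 2.0, 0.5]])
targets = torch.tensor([1, 0])   # second target is padding -> ignored

# Loss equals the NLL of the single non-padded position
loss = criterion(logits, targets)
print(loss.item())
```

Because ignored positions are excluded from the mean, loss magnitudes stay comparable across batches with different amounts of padding.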
Inference and Evaluation
At inference time, we don’t have target sequences, so we use greedy decoding or beam search. Greedy decoding is simpler and faster but can miss better sequences.
```python
def greedy_decode(model, src, src_length, max_len=50, sos_idx=1, eos_idx=2):
    model.eval()
    with torch.no_grad():
        encoder_outputs, hidden, cell = model.encoder(src, src_length)
        input_token = torch.tensor([[sos_idx]]).to(src.device)
        decoded_tokens = []
        for _ in range(max_len):
            prediction, hidden, cell, _ = model.decoder(
                input_token, hidden, cell, encoder_outputs
            )
            predicted_token = prediction.argmax(1).item()
            decoded_tokens.append(predicted_token)
            if predicted_token == eos_idx:
                break
            input_token = torch.tensor([[predicted_token]]).to(src.device)
    return decoded_tokens
```
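Beam search instead keeps the k best partial hypotheses at each step, ranked by cumulative log-probability. A minimal, model-agnostic sketch; the `step_fn` interface and the toy distribution below are illustrative assumptions, not part of the model classes above:

```python
import math

def beam_search(step_fn, sos_idx, eos_idx, beam_width=3, max_len=10):
    """Generic beam search. step_fn(prefix) returns a list of
    (token, log_prob) candidates for the next token."""
    beams = [([sos_idx], 0.0)]           # (token sequence, cumulative log-prob)
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos_idx:       # finished beams carry over unchanged
                candidates.append((seq, score))
                continue
            for token, logp in step_fn(seq):
                candidates.append((seq + [token], score + logp))
        # keep the top-k hypotheses by cumulative log-probability
        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(seq[-1] == eos_idx for seq, _ in beams):
            break
    return beams[0][0]

# Toy next-token distribution: prefers token 3, then strongly prefers <eos> (2)
def toy_step(prefix):
    if len(prefix) >= 4:
        return [(2, math.log(0.9)), (3, math.log(0.1))]
    return [(3, math.log(0.6)), (4, math.log(0.3)), (2, math.log(0.1))]

print(beam_search(toy_step, sos_idx=1, eos_idx=2))   # [1, 3, 3, 3, 2]
```

Wiring this to the decoder above would mean having `step_fn` run one decoder step per prefix and return the top-k log-softmax scores; production systems usually also length-normalize the beam scores.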
For evaluation, use BLEU scores for translation tasks or task-specific metrics. BLEU measures n-gram overlap between predictions and references.
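As a rough illustration of the idea, here is a simplified single-reference sentence-level BLEU: clipped n-gram precision with a brevity penalty and a crude smoothing floor. For real evaluations, use a standard implementation such as sacrebleu rather than this sketch:

```python
import math
from collections import Counter

def ngram_counts(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def sentence_bleu(candidate, reference, max_n=4):
    """Clipped n-gram precision (orders 1..max_n) with a brevity penalty."""
    log_precisions = []
    for n in range(1, max_n + 1):
        cand = ngram_counts(candidate, n)
        ref = ngram_counts(reference, n)
        # Clip each candidate n-gram count by its count in the reference
        clipped = sum(min(count, ref[gram]) for gram, count in cand.items())
        total = max(sum(cand.values()), 1)
        # Crude smoothing: floor of 1 avoids log(0) when nothing matches
        log_precisions.append(math.log(max(clipped, 1) / total))
    # Brevity penalty punishes candidates shorter than the reference
    bp = min(1.0, math.exp(1 - len(reference) / len(candidate)))
    return bp * math.exp(sum(log_precisions) / max_n)

cand = "the cat sat on the mat".split()
print(sentence_bleu(cand, cand))   # identical sentences -> 1.0
```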
Best Practices and Optimization Tips
Batch padding efficiently: Use collate functions to pad sequences dynamically within batches rather than padding the entire dataset to the maximum length.
```python
from torch.utils.data import DataLoader

def collate_fn(batch):
    src_batch, trg_batch = zip(*batch)
    src_lengths = torch.tensor([len(s) for s in src_batch])
    # Pad sequences to the longest length in this batch only
    src_padded = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=0)
    trg_padded = nn.utils.rnn.pad_sequence(trg_batch, batch_first=True, padding_value=0)
    return src_padded, src_lengths, trg_padded

# Use with DataLoader
dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn, shuffle=True)
```
Always clip gradients to prevent exploding gradients in recurrent networks. A max norm of 1.0 works well in most cases.
Schedule teacher forcing ratio from high (0.8-1.0) to low (0.3-0.5) over training to reduce exposure bias. Start with high ratios for stable early training, then gradually expose the model to its own predictions.
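One simple option is a linear decay of the ratio across epochs; the start and end values here are illustrative, not prescriptive:

```python
def teacher_forcing_ratio(epoch, num_epochs, start=0.9, end=0.4):
    """Linearly decay the teacher forcing ratio over training."""
    frac = epoch / max(num_epochs - 1, 1)
    return start + (end - start) * frac

# Ratio passed to train_epoch each epoch
for epoch in range(5):
    print(epoch, round(teacher_forcing_ratio(epoch, 5), 2))
```

Exponential decay or inverse-sigmoid schedules (as in scheduled sampling) are common alternatives; the key point is that the decay is gradual.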
Use packed sequences to avoid wasting computation on padding tokens—this can speed up training by 2-3x on datasets with variable-length sequences.
Initialize embeddings with pretrained vectors (Word2Vec, GloVe) when possible, especially for low-resource scenarios.
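With pretrained vectors arranged in a matrix whose rows align with your vocabulary indices, `nn.Embedding.from_pretrained` drops in directly. The tiny matrix below is a stand-in for real GloVe/Word2Vec rows:

```python
import torch
import torch.nn as nn

# Hypothetical pretrained matrix: vocab of 4 tokens, 3-dim vectors
pretrained = torch.tensor([
    [0.0, 0.0, 0.0],   # row 0: <pad>
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
])

# freeze=False lets the vectors be fine-tuned during training
embedding = nn.Embedding.from_pretrained(pretrained, freeze=False, padding_idx=0)
print(embedding(torch.tensor([1, 2])).shape)   # torch.Size([2, 3])
```

This replaces the `nn.Embedding(input_vocab_size, embedding_dim)` line in the encoder or decoder; tokens missing from the pretrained vocabulary are typically initialized randomly.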
Seq2seq models remain foundational despite transformer architectures dominating recent benchmarks. Understanding their mechanics—encoder-decoder structure, attention mechanisms, and training dynamics—provides essential intuition for modern sequence modeling approaches.