How to Implement Seq2Seq Models in PyTorch

Key Insights

  • Seq2seq models use an encoder-decoder architecture to transform variable-length input sequences into variable-length outputs, making them ideal for translation, summarization, and dialogue systems
  • Attention mechanisms are essential for production seq2seq models—they solve the information bottleneck problem by letting decoders focus on relevant input positions rather than compressing everything into a fixed context vector
  • Teacher forcing during training (feeding ground truth instead of predictions) accelerates convergence but requires careful scheduling to prevent exposure bias at inference time

Introduction to Sequence-to-Sequence Architecture

Sequence-to-sequence (seq2seq) models solve a fundamental problem in machine learning: mapping variable-length input sequences to variable-length output sequences. Unlike traditional neural networks that require fixed-size inputs and outputs, seq2seq architectures handle sequences of arbitrary length.

Common applications include machine translation (English to French), text summarization (long article to summary), chatbots (user message to response), and code generation. The architecture consists of two main components: an encoder that processes the input sequence into a fixed-size context representation, and a decoder that generates the output sequence from this context.

The basic seq2seq model has a critical limitation—compressing all input information into a single fixed-size vector creates an information bottleneck. This is where attention mechanisms become crucial, allowing the decoder to selectively focus on different parts of the input at each generation step.

Building the Encoder

The encoder’s job is straightforward: read the input sequence and produce a meaningful representation. We use recurrent layers (LSTM or GRU) to process sequences while maintaining temporal dependencies. An embedding layer first converts discrete tokens into continuous vectors.

Here’s an encoder implementation that handles padded, variable-length batches:

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_vocab_size, embedding_dim, hidden_dim, num_layers=1, dropout=0.1):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        self.embedding = nn.Embedding(input_vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim, 
            hidden_dim, 
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, src_lengths):
        # src shape: (batch_size, seq_len)
        embedded = self.dropout(self.embedding(src))
        # embedded shape: (batch_size, seq_len, embedding_dim)
        
        # Pack padded sequences for efficiency
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, src_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        
        outputs, (hidden, cell) = self.lstm(packed)
        
        # Unpack sequences
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        # outputs shape: (batch_size, seq_len, hidden_dim)
        # hidden shape: (num_layers, batch_size, hidden_dim)
        
        return outputs, hidden, cell

The key details: we use pack_padded_sequence to avoid wasting computation on padding tokens, and we return both the full output sequence (needed for attention) and the final hidden states (the context vector).
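
To make the packing behaviour concrete, here is a small sketch that runs a bare nn.LSTM over a toy padded batch. The tensor sizes and lengths are made up for illustration, and random vectors stand in for the embedding output. It checks two things the encoder relies on: padded positions come back as zeros, and the final hidden state matches the output at each sequence's last real step.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Toy batch (illustrative sizes): 3 sequences padded to length 5,
# with true lengths 5, 3 and 2.
batch_size, max_len, embedding_dim, hidden_dim = 3, 5, 4, 6
embedded = torch.randn(batch_size, max_len, embedding_dim)
lengths = torch.tensor([5, 3, 2])

lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

# Pack, run the LSTM, then unpack back to a padded tensor
packed = pack_padded_sequence(
    embedded, lengths, batch_first=True, enforce_sorted=False
)
packed_out, (hidden, cell) = lstm(packed)
outputs, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Positions past each sequence's true length are filled with zeros
padding_is_zero = bool(
    (outputs[1, 3:] == 0).all() and (outputs[2, 2:] == 0).all()
)

# For a single-layer, unidirectional LSTM the final hidden state equals
# the output at the last valid time step of each sequence
last_valid = outputs[torch.arange(batch_size), lengths - 1]
hidden_matches_last_step = torch.allclose(hidden[-1], last_valid)
```

Without packing, the final hidden state of the shorter sequences would be computed after reading padding tokens, so packing affects correctness as well as speed.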

Building the Decoder

The decoder generates output sequences one token at a time. At each step, it receives the previous token (or the ground truth during teacher forcing), the previous hidden state, and optionally an attention-weighted context from the encoder outputs.

class Decoder(nn.Module):
    def __init__(self, output_vocab_size, embedding_dim, hidden_dim, num_layers=1, dropout=0.1):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_vocab_size = output_vocab_size
        
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        self.fc_out = nn.Linear(hidden_dim, output_vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input_token, hidden, cell):
        # input_token shape: (batch_size, 1)
        embedded = self.dropout(self.embedding(input_token))
        # embedded shape: (batch_size, 1, embedding_dim)
        
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        # output shape: (batch_size, 1, hidden_dim)
        
        prediction = self.fc_out(output.squeeze(1))
        # prediction shape: (batch_size, output_vocab_size)
        
        return prediction, hidden, cell

Teacher forcing is critical here. During training, instead of feeding the decoder’s own predictions as input for the next step, we feed the ground truth tokens. This dramatically speeds up training but can create exposure bias—the model never learns to recover from its own mistakes.

Implementing the Attention Mechanism

Attention is the standard remedy for the fixed-context bottleneck in RNN-based seq2seq models. It lets the decoder build a different context vector at each decoding step by computing a weighted sum over the encoder outputs.

Here’s a Bahdanau-style attention implementation:

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_dim, hidden_dim)
        self.Ua = nn.Linear(hidden_dim, hidden_dim)
        self.Va = nn.Linear(hidden_dim, 1)
        
    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        # decoder_hidden shape: (batch_size, hidden_dim)
        # encoder_outputs shape: (batch_size, src_len, hidden_dim)
        # mask shape (optional): (batch_size, src_len), nonzero for real tokens, 0 for padding
        
        # Score calculation
        decoder_hidden = decoder_hidden.unsqueeze(1)  # (batch_size, 1, hidden_dim)
        score = self.Va(torch.tanh(
            self.Wa(decoder_hidden) + self.Ua(encoder_outputs)
        ))  # (batch_size, src_len, 1)
        
        score = score.squeeze(2)  # (batch_size, src_len)
        
        # Apply mask for padded positions
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e10)
        
        attention_weights = torch.softmax(score, dim=1)  # (batch_size, src_len)
        
        # Context vector: weighted sum of encoder outputs
        context = torch.bmm(
            attention_weights.unsqueeze(1), 
            encoder_outputs
        ).squeeze(1)  # (batch_size, hidden_dim)
        
        return context, attention_weights


class AttentionDecoder(nn.Module):
    def __init__(self, output_vocab_size, embedding_dim, hidden_dim, num_layers=1, dropout=0.1):
        super(AttentionDecoder, self).__init__()
        # Kept for parity with Decoder; code that sizes output tensors may read it
        self.output_vocab_size = output_vocab_size
        self.embedding = nn.Embedding(output_vocab_size, embedding_dim)
        self.attention = BahdanauAttention(hidden_dim)
        self.lstm = nn.LSTM(
            embedding_dim + hidden_dim,  # Concatenate embedding with context
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        self.fc_out = nn.Linear(hidden_dim, output_vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input_token, hidden, cell, encoder_outputs, src_mask=None):
        embedded = self.dropout(self.embedding(input_token))
        
        # Get attention context using top layer hidden state
        context, attention_weights = self.attention(
            hidden[-1], encoder_outputs, src_mask
        )
        
        # Concatenate embedding with context
        lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)
        
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        prediction = self.fc_out(output.squeeze(1))
        
        return prediction, hidden, cell, attention_weights
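
The attention module accepts an optional mask, but the code above does not show how to build one. A minimal sketch, assuming the mask convention used in BahdanauAttention (nonzero for real tokens, 0 for padding) and right-padded batches; make_src_mask is a hypothetical helper name, not part of PyTorch:

```python
import torch

def make_src_mask(src_lengths, max_len=None):
    """Return a (batch_size, max_len) mask: 1 for real tokens, 0 for padding.

    Hypothetical helper; assumes sequences are right-padded.
    """
    src_lengths = torch.as_tensor(src_lengths)
    if max_len is None:
        max_len = int(src_lengths.max())
    # positions: (1, max_len), compared against lengths: (batch_size, 1)
    positions = torch.arange(max_len, device=src_lengths.device).unsqueeze(0)
    return (positions < src_lengths.unsqueeze(1)).long()

mask = make_src_mask(torch.tensor([4, 2, 3]))

# Same masking step as in the attention module: padded scores are pushed
# to a large negative value so softmax assigns them (numerically) zero weight
scores = torch.zeros(3, 4)
weights = torch.softmax(scores.masked_fill(mask == 0, -1e10), dim=1)
```

The resulting mask would be passed as src_mask to AttentionDecoder.forward, with max_len set to the second dimension of encoder_outputs.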

Training the Complete Model

Training seq2seq models requires careful orchestration of the encoder, decoder, and loss computation. Use cross-entropy loss and implement teacher forcing with a scheduled ratio.

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
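    def forward(self, src, src_lengths, trg, teacher_forcing_ratio=0.5):
        # Expects an attention decoder that returns (prediction, hidden, cell, weights)
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        # Read the vocabulary size from the output layer so any decoder with fc_out works
        trg_vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size, device=src.device)
        encoder_outputs, hidden, cell = self.encoder(src, src_lengths)

        # Mask padded source positions (True = real token) so attention ignores them
        positions = torch.arange(encoder_outputs.shape[1], device=src.device)
        src_mask = positions.unsqueeze(0) < src_lengths.to(src.device).unsqueeze(1)

        # First input is <SOS> token
        input_token = trg[:, 0].unsqueeze(1)

        for t in range(1, trg_len):
            prediction, hidden, cell, _ = self.decoder(
                input_token, hidden, cell, encoder_outputs, src_mask
            )
            outputs[:, t] = prediction

            # Teacher forcing decision (made once per step for the whole batch)
            use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio
            top1 = prediction.argmax(1).unsqueeze(1)
            input_token = trg[:, t].unsqueeze(1) if use_teacher_forcing else top1

        return outputs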


def train_epoch(model, dataloader, optimizer, criterion, device, teacher_forcing_ratio=0.5):
    model.train()
    epoch_loss = 0
    
    for batch in dataloader:
        src, src_lengths, trg = batch
        src, trg = src.to(device), trg.to(device)
        
        optimizer.zero_grad()
        output = model(src, src_lengths, trg, teacher_forcing_ratio)
        
        # Reshape for loss calculation
        output_dim = output.shape[-1]
        output = output[:, 1:].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)
        
        loss = criterion(output, trg)
        loss.backward()
        
        # Gradient clipping prevents exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        epoch_loss += loss.item()
        
    return epoch_loss / len(dataloader)
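
train_epoch takes a criterion argument that the article does not define. Because target batches are zero-padded, one reasonable choice is cross-entropy with ignore_index set to the padding index, so padded positions do not contribute to the loss. The sketch below assumes index 0 is reserved for padding and uses random logits purely for illustration:

```python
import torch
import torch.nn as nn

PAD_IDX = 0  # assumption: index 0 is reserved for the padding token

torch.manual_seed(0)
logits = torch.randn(4, 5)  # 4 flattened target positions, vocabulary of 5
targets = torch.tensor([3, 1, PAD_IDX, PAD_IDX])  # last two are padding

# Loss averaged only over non-padding positions
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
masked_loss = criterion(logits, targets)

# Reference: plain cross-entropy over the two real positions only
reference_loss = nn.CrossEntropyLoss()(logits[:2], targets[:2])

# For comparison: without ignore_index, padding is scored like any other class
unmasked_loss = nn.CrossEntropyLoss()(logits, targets)
```

Without ignore_index, the model is rewarded for predicting the padding token, which can dominate the loss when batches contain many short sequences.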

Inference and Evaluation

At inference time, we don’t have target sequences, so we use greedy decoding or beam search. Greedy decoding is simpler and faster but can miss better sequences.

def greedy_decode(model, src, src_length, max_len=50, sos_idx=1, eos_idx=2):
    # Decodes one sequence at a time: src is (1, src_len), src_length is a tensor of shape (1,)
    model.eval()
    with torch.no_grad():
        encoder_outputs, hidden, cell = model.encoder(src, src_length)
        
        input_token = torch.tensor([[sos_idx]]).to(src.device)
        decoded_tokens = []
        
        for _ in range(max_len):
            prediction, hidden, cell, _ = model.decoder(
                input_token, hidden, cell, encoder_outputs
            )
            
            predicted_token = prediction.argmax(1).item()
            decoded_tokens.append(predicted_token)
            
            if predicted_token == eos_idx:
                break
                
            input_token = torch.tensor([[predicted_token]]).to(src.device)
            
    return decoded_tokens
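
Beam search is mentioned above but not shown. The following is a framework-free sketch of the idea, not a drop-in replacement for greedy_decode. It assumes a hypothetical step_fn(tokens) that returns a dict of next-token log-probabilities; with the model above, that function would wrap model.decoder and carry the hidden state along with each hypothesis. Scores are plain summed log-probabilities, with no length normalization.

```python
import math

def beam_search(step_fn, sos_idx, eos_idx, beam_width=3, max_len=50):
    """Return (score, tokens) for the best hypothesis found."""
    beams = [(0.0, [sos_idx])]  # live hypotheses: (summed log-prob, tokens)
    finished = []
    for _ in range(max_len):
        # Expand every live hypothesis by every possible next token
        candidates = []
        for score, tokens in beams:
            for token, log_prob in step_fn(tokens).items():
                candidates.append((score + log_prob, tokens + [token]))
        if not candidates:
            break
        # Keep only the top beam_width candidates
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, tokens in candidates[:beam_width]:
            if tokens[-1] == eos_idx:
                finished.append((score, tokens))
            else:
                beams.append((score, tokens))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off by max_len
    return max(finished, key=lambda c: c[0])

# Toy distribution (made up) where the greedy first choice is a trap
TOY = {
    (1,): {3: math.log(0.6), 0: math.log(0.4)},
    (1, 3): {0: math.log(0.55), 3: math.log(0.45)},
    (1, 0): {2: math.log(1.0)},
    (1, 3, 0): {2: math.log(1.0)},
    (1, 3, 3): {2: math.log(1.0)},
}

def toy_step(tokens):
    return TOY.get(tuple(tokens), {})

greedy_score, greedy_tokens = beam_search(toy_step, 1, 2, beam_width=1)
best_score, best_tokens = beam_search(toy_step, 1, 2, beam_width=2)
```

With a beam width of 1 this reduces to greedy decoding, which commits to token 3 and ends with probability 0.33. A width of 2 keeps the alternative alive and finds the sequence with probability 0.4.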

For evaluation, use BLEU scores for translation tasks or task-specific metrics. BLEU measures n-gram overlap between predictions and references.
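
To show what "n-gram overlap" means in practice, here is a sketch of the clipped (modified) n-gram precision at the core of BLEU, for a single candidate and a single reference. It is only one component: full BLEU combines precisions for several n-gram orders and applies a brevity penalty, so for reported scores an established implementation such as sacreBLEU is the safer choice. The function names are illustrative.

```python
from collections import Counter

def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def modified_precision(candidate, reference, n):
    """Clipped n-gram precision of candidate against one reference."""
    cand_counts = Counter(ngrams(candidate, n))
    ref_counts = Counter(ngrams(reference, n))
    if not cand_counts:
        return 0.0
    # Each candidate n-gram is credited at most as often as it
    # appears in the reference
    clipped = sum(
        min(count, ref_counts[gram]) for gram, count in cand_counts.items()
    )
    return clipped / sum(cand_counts.values())

reference = "the cat is on the mat".split()
candidate = "the the the cat mat".split()

p1 = modified_precision(candidate, reference, 1)  # 4 of 5 unigrams credited
p2 = modified_precision(candidate, reference, 2)  # 1 of 4 bigrams credited
```

Clipping is what stops a degenerate output like "the the the" from scoring well: the reference contains "the" twice, so only two of the three occurrences count.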

Best Practices and Optimization Tips

Pad batches efficiently: Use a collate function to pad sequences dynamically within each batch rather than padding the entire dataset to the maximum length.

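from torch.utils.data import DataLoader

def collate_fn(batch):
    # Each item is a (src_tensor, trg_tensor) pair of 1-D token index tensors
    src_batch, trg_batch = zip(*batch)
    src_lengths = torch.tensor([len(s) for s in src_batch])
    
    # Pad sequences; padding_value must match the padding index used elsewhere
    src_padded = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=0)
    trg_padded = nn.utils.rnn.pad_sequence(trg_batch, batch_first=True, padding_value=0)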
    
    return src_padded, src_lengths, trg_padded

# Use with DataLoader
dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn, shuffle=True)

Clip gradients to guard against exploding gradients in recurrent networks. A max norm of 1.0 is a common starting point; tune it if training is unstable or unusually slow.

Schedule the teacher forcing ratio from high (0.8-1.0) to low (0.3-0.5) over training to reduce exposure bias. Start with high ratios for stable early training, then gradually expose the model to its own predictions.
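
One simple way to implement such a schedule is a linear decay per epoch. This is a sketch under assumed defaults (start at 1.0, end at 0.3); scheduled_tf_ratio is a hypothetical helper, and other decay shapes such as exponential or inverse-sigmoid are equally valid choices.

```python
def scheduled_tf_ratio(epoch, num_epochs, start=1.0, end=0.3):
    """Linearly decay the teacher forcing ratio.

    Returns `start` at epoch 0 and `end` at the final epoch
    (epoch == num_epochs - 1). Epochs are zero-indexed.
    """
    if num_epochs <= 1:
        return start
    # Clamp progress to [0, 1] so out-of-range epochs stay within bounds
    progress = min(max(epoch / (num_epochs - 1), 0.0), 1.0)
    return start + (end - start) * progress

# Ratio for each epoch of a 5-epoch run
schedule = [scheduled_tf_ratio(epoch, 5) for epoch in range(5)]
```

Inside the training loop, the value for the current epoch would be passed as the teacher_forcing_ratio argument of train_epoch.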

Use packed sequences to avoid wasting computation on padding tokens. This can speed up training noticeably, in some cases by 2-3x, with the gain depending on how much sequence lengths vary within each batch.

Initialize embeddings with pretrained vectors (Word2Vec, GloVe) when possible, especially for low-resource scenarios.
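
A minimal sketch of that initialization using nn.Embedding.from_pretrained. The vocabulary and the vectors here are tiny made-up stand-ins for what would normally be read from a GloVe or Word2Vec file; words without a pretrained vector keep a small random initialization.

```python
import torch
import torch.nn as nn

# Hypothetical vocabulary and "pretrained" vectors, for illustration only
vocab = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "cat": 3, "dog": 4}
embedding_dim = 4
pretrained = {
    "cat": [0.1, 0.2, 0.3, 0.4],
    "dog": [0.5, 0.6, 0.7, 0.8],
}

torch.manual_seed(0)
# Small random init for words that have no pretrained vector
weights = torch.randn(len(vocab), embedding_dim) * 0.1
weights[vocab["<pad>"]] = 0.0  # keep the padding vector at zero
for word, vector in pretrained.items():
    weights[vocab[word]] = torch.tensor(vector)

# freeze=False lets the vectors be fine-tuned during training
embedding = nn.Embedding.from_pretrained(
    weights, freeze=False, padding_idx=vocab["<pad>"]
)

cat_vector = embedding(torch.tensor([vocab["cat"]]))
```

In the Encoder above, the resulting module would take the place of self.embedding, provided embedding_dim matches the dimension of the pretrained vectors. Setting freeze=True instead keeps the vectors fixed, which can help when training data is scarce.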

Seq2seq models remain foundational despite transformer architectures dominating recent benchmarks. Understanding their mechanics—encoder-decoder structure, attention mechanisms, and training dynamics—provides essential intuition for modern sequence modeling approaches.
