How to Implement Seq2Seq Models in TensorFlow
Key Insights
- Seq2seq models use an encoder-decoder architecture where the encoder compresses input sequences into context vectors and the decoder generates output sequences, making them ideal for translation, summarization, and conversational AI
- Attention mechanisms dramatically improve seq2seq performance by allowing the decoder to focus on relevant parts of the input sequence rather than relying solely on a fixed context vector
- Teacher forcing during training (feeding ground truth as decoder input) accelerates convergence, but inference requires autoregressive generation where the model feeds its own predictions back as input
Introduction to Sequence-to-Sequence Models
Sequence-to-sequence (seq2seq) models revolutionized how we approach problems where both input and output are sequences of variable length. Unlike traditional fixed-size input-output models, seq2seq architectures can handle scenarios where a sentence in English needs to become a sentence in French, or where a document must be condensed into a summary.
The core innovation is the encoder-decoder paradigm: the encoder processes the entire input sequence and compresses it into a fixed-size representation (context vector), which the decoder then uses to generate the output sequence one token at a time. This architecture powers machine translation systems, chatbots, text summarization tools, and even code generation applications.
In this article, you’ll build a complete seq2seq model in TensorFlow, starting with basic components and progressing to a production-ready translation system with attention mechanisms.
Understanding the Encoder-Decoder Architecture
The encoder-decoder architecture consists of two recurrent neural networks working in tandem. The encoder reads the input sequence token by token, updating its hidden state at each step. The final hidden state becomes the context vector—a dense representation capturing the input’s semantic meaning.
The decoder is another RNN that generates the output sequence. It initializes with the encoder’s context vector and produces one token at a time, using its own hidden state and previously generated tokens to predict the next token.
Here’s a conceptual implementation showing the data flow:
import tensorflow as tf

# Simplified encoder-decoder flow (encoder and decoder are defined below)
def encode_decode_flow(input_sequence, target_sequence):
    # Encoder: input_sequence -> context vector (the final LSTM states)
    encoder_outputs, encoder_state_h, encoder_state_c = encoder(input_sequence)
    # Decoder: context vector + target_sequence -> predictions
    decoder_outputs = decoder(target_sequence,
                              initial_state=[encoder_state_h, encoder_state_c])
    return decoder_outputs
The hidden states are crucial. For LSTM cells, you maintain both a hidden state (h) and a cell state (c), which together carry information through the sequence. These states allow the network to maintain long-term dependencies.
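You can see these two states concretely by running a standalone LSTM layer on random data (the sizes below are arbitrary, chosen only for illustration):

```python
import tensorflow as tf

# A toy batch: 4 sequences, 7 timesteps, 16 features per timestep
x = tf.random.normal((4, 7, 16))
lstm = tf.keras.layers.LSTM(32, return_sequences=True, return_state=True)
outputs, state_h, state_c = lstm(x)

print(outputs.shape)  # (4, 7, 32): one output vector per timestep
print(state_h.shape)  # (4, 32): final hidden state
print(state_c.shape)  # (4, 32): final cell state
```

Note that `state_h` equals the last timestep of `outputs`; the cell state `state_c` is the extra memory channel that only LSTMs carry.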
Building the Encoder with LSTM Layers
Let’s implement a robust encoder using TensorFlow’s Keras API. The encoder consists of an embedding layer to convert token indices into dense vectors, followed by LSTM layers that process the sequence.
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):
        super(Encoder, self).__init__()
        self.batch_size = batch_size
        self.enc_units = enc_units
        # Embedding layer converts token indices to dense vectors
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        # return_state=True exposes the final h and c states;
        # return_sequences=True keeps per-timestep outputs for attention
        self.lstm = tf.keras.layers.LSTM(
            self.enc_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform'
        )

    def call(self, x, hidden):
        # x shape: (batch_size, max_length)
        x = self.embedding(x)  # (batch_size, max_length, embedding_dim)
        output, state_h, state_c = self.lstm(x, initial_state=hidden)
        # output shape: (batch_size, max_length, enc_units)
        # state_h, state_c shape: (batch_size, enc_units)
        return output, state_h, state_c

    def initialize_hidden_state(self):
        return [tf.zeros((self.batch_size, self.enc_units)),
                tf.zeros((self.batch_size, self.enc_units))]
The return_sequences=True parameter is critical—we need the encoder outputs at every timestep for the attention mechanism, not just the final output.
Implementing the Decoder with Attention Mechanism
Basic seq2seq models struggle with long sequences because the entire input must be compressed into a fixed-size context vector. Attention mechanisms solve this by allowing the decoder to “look back” at all encoder outputs, focusing on relevant parts of the input for each output token.
Here’s a Bahdanau attention implementation:
class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # query: decoder hidden state (batch_size, hidden_size)
        # values: encoder outputs (batch_size, max_length, hidden_size)
        # Expand query to (batch_size, 1, hidden_size) for broadcasting
        query_with_time_axis = tf.expand_dims(query, 1)
        # score shape: (batch_size, max_length, 1)
        score = self.V(tf.nn.tanh(
            self.W1(query_with_time_axis) + self.W2(values)))
        # attention_weights shape: (batch_size, max_length, 1);
        # softmax over axis=1 normalizes across the input timesteps
        attention_weights = tf.nn.softmax(score, axis=1)
        # Weighted sum over timesteps -> (batch_size, hidden_size)
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights
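The softmax over axis 1 is what makes this "attention": the weights form a probability distribution over input timesteps, and the context vector is the corresponding weighted average of encoder outputs. A numpy sketch of that arithmetic, with made-up scores and encoder outputs:

```python
import numpy as np

# Hypothetical alignment scores for 3 input timesteps (one decoder step)
scores = np.array([2.0, 0.5, -1.0])
weights = np.exp(scores) / np.exp(scores).sum()   # softmax over timesteps

# Hypothetical encoder outputs: (max_length=3, hidden_size=2)
values = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
context = (weights[:, None] * values).sum(axis=0)  # weighted average

print(weights.sum())  # 1.0: a proper distribution over input positions
print(context.shape)  # (2,): same size as a single encoder output
```

The timestep with the highest score dominates the context vector, which is exactly the "focus on relevant parts of the input" behavior described above.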
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_size):
        super(Decoder, self).__init__()
        self.batch_size = batch_size
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm = tf.keras.layers.LSTM(
            self.dec_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform'
        )
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):
        # x shape: (batch_size, 1)
        # hidden: [state_h, state_c], each (batch_size, dec_units)
        # enc_output: (batch_size, max_length, enc_units)
        context_vector, attention_weights = self.attention(hidden[0], enc_output)
        # x after embedding: (batch_size, 1, embedding_dim)
        x = self.embedding(x)
        # Concatenate context vector with embedded input:
        # (batch_size, 1, embedding_dim + enc_units)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
        output, state_h, state_c = self.lstm(x, initial_state=hidden)
        # output shape: (batch_size, 1, dec_units) -> (batch_size, dec_units)
        output = tf.reshape(output, (-1, output.shape[2]))
        # Final prediction logits: (batch_size, vocab_size)
        x = self.fc(output)
        return x, [state_h, state_c], attention_weights
Training the Model with Teacher Forcing
Teacher forcing is a training technique where we feed the actual target token as input to the decoder at each step, rather than the decoder’s own prediction. This accelerates training by preventing error accumulation.
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # Mask padding tokens (index 0) so they contribute no loss
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_mean(loss_)

@tf.function
def train_step(inp, targ, enc_hidden, encoder, decoder):
    # targ_lang and BATCH_SIZE are module-level globals (see the complete example)
    loss = 0
    with tf.GradientTape() as tape:
        enc_output, enc_h, enc_c = encoder(inp, enc_hidden)
        dec_hidden = [enc_h, enc_c]
        # First decoder input: a <start> token for every sequence in the batch
        dec_input = tf.expand_dims(
            [targ_lang.word_index['<start>']] * BATCH_SIZE, 1)
        # Step through the target sequence one position at a time
        for t in range(1, targ.shape[1]):
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            loss += loss_function(targ[:, t], predictions)
            # Teacher forcing: use the ground-truth token as the next input
            dec_input = tf.expand_dims(targ[:, t], 1)
    batch_loss = loss / int(targ.shape[1])
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return batch_loss
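The masking in loss_function matters more than it looks: without it, the model would be rewarded for predicting padding. A small numpy illustration with hypothetical per-token losses (dividing by the mask sum, as shown, averages over real tokens only):

```python
import numpy as np

# One target sequence, padded with 0s after the third token
real = np.array([5, 9, 2, 0, 0])
per_token_loss = np.array([1.2, 0.8, 0.5, 0.7, 0.9])  # hypothetical values

mask = (real != 0).astype(per_token_loss.dtype)
masked_loss = per_token_loss * mask

print(masked_loss)                     # padding positions contribute zero loss
print(masked_loss.sum() / mask.sum())  # mean over the real tokens only
```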
Inference and Beam Search Decoding
During inference, we don’t have target sequences. The decoder must generate tokens autoregressively, feeding its own predictions back as input. Greedy decoding selects the highest probability token at each step, while beam search maintains multiple hypotheses for better quality.
def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length):
    # units is a module-level global; numpy is imported as np
    attention_plot = np.zeros((max_length, max_length))
    # Preprocess the input sentence (unknown words map to the <unk> index)
    inputs = [inp_lang.word_index.get(i, inp_lang.word_index['<unk>'])
              for i in sentence.split(' ')]
    inputs = tf.keras.preprocessing.sequence.pad_sequences(
        [inputs], maxlen=max_length, padding='post')
    inputs = tf.convert_to_tensor(inputs)
    result = ''
    hidden = [tf.zeros((1, units)), tf.zeros((1, units))]
    enc_out, enc_h, enc_c = encoder(inputs, hidden)
    dec_hidden = [enc_h, enc_c]
    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)
    for t in range(max_length):
        predictions, dec_hidden, attention_weights = decoder(
            dec_input, dec_hidden, enc_out)
        attention_weights = tf.reshape(attention_weights, (-1,))
        attention_plot[t] = attention_weights.numpy()
        # Greedy decoding: select the highest-probability token
        predicted_id = int(tf.argmax(predictions[0]).numpy())
        word = targ_lang.index_word.get(predicted_id, '')
        if word == '<end>':
            return result, sentence, attention_plot
        result += word + ' '
        # Feed the prediction back as the next input
        dec_input = tf.expand_dims([predicted_id], 0)
    return result, sentence, attention_plot
def beam_search_decoder(encoder, decoder, inp_lang, targ_lang,
                        sentence, beam_width=3, max_length=50):
    # Encode the input sentence (units is a module-level global)
    inputs = [inp_lang.word_index.get(i, inp_lang.word_index['<unk>'])
              for i in sentence.split(' ')]
    inputs = tf.keras.preprocessing.sequence.pad_sequences(
        [inputs], maxlen=max_length, padding='post')
    inputs = tf.convert_to_tensor(inputs)
    hidden = [tf.zeros((1, units)), tf.zeros((1, units))]
    enc_out, enc_h, enc_c = encoder(inputs, hidden)
    # Each beam entry: (token ids, decoder state, cumulative negative log-prob)
    beams = [([targ_lang.word_index['<start>']], [enc_h, enc_c], 0.0)]
    for _ in range(max_length):
        all_candidates = []
        for tokens, dec_hidden, score in beams:
            # Finished hypotheses are carried over unchanged
            if tokens[-1] == targ_lang.word_index['<end>']:
                all_candidates.append((tokens, dec_hidden, score))
                continue
            dec_input = tf.expand_dims([tokens[-1]], 0)
            predictions, new_hidden, _ = decoder(dec_input, dec_hidden, enc_out)
            # The decoder outputs raw logits; convert to log-probabilities
            # before scoring (taking log of a logit would be meaningless)
            log_probs = tf.nn.log_softmax(predictions[0])
            top_k = tf.nn.top_k(log_probs, k=beam_width)
            for i in range(beam_width):
                all_candidates.append((
                    tokens + [int(top_k.indices[i].numpy())],
                    new_hidden,
                    score - float(top_k.values[i].numpy())
                ))
        # Keep the beam_width lowest-cost hypotheses
        beams = sorted(all_candidates, key=lambda x: x[2])[:beam_width]
    # Return the best sequence (skip the <start> token)
    best_tokens = beams[0][0]
    return ' '.join(targ_lang.index_word.get(i, '') for i in best_tokens[1:])
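One caveat with the cumulative score above: it sums negative log-probabilities, so longer hypotheses are systematically penalized. A common remedy is length normalization; the sketch below uses a GNMT-style length penalty (the alpha value is a tunable assumption, typically around 0.6-0.7):

```python
def length_normalized_score(score, length, alpha=0.7):
    """Divide a cumulative negative log-prob by a length penalty so that
    short and long hypotheses can be compared fairly (lower is better)."""
    lp = ((5 + length) ** alpha) / ((5 + 1) ** alpha)
    return score / lp

# A longer hypothesis with a worse raw score can win after normalization
print(length_normalized_score(6.0, 10) < length_normalized_score(5.0, 4))  # True
```

To use it, rank the finished beams by `length_normalized_score(score, len(tokens))` instead of the raw `score` in the final `sorted` call.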
Complete Example: English-to-French Translation
Here’s a minimal end-to-end implementation you can run:
import tensorflow as tf
import numpy as np

# Hyperparameters (BATCH_SIZE matches the toy dataset; use 64+ for real data)
BATCH_SIZE = 2
embedding_dim = 256
units = 1024
EPOCHS = 10

# Sample data (in practice, use real datasets like WMT)
eng_sentences = ["I love machine learning", "This is a test"]
fra_sentences = ["J'aime l'apprentissage automatique", "Ceci est un test"]

# Tokenization (simplified); oov_token provides the '<unk>' index used at inference
inp_lang = tf.keras.preprocessing.text.Tokenizer(filters='', oov_token='<unk>')
inp_lang.fit_on_texts(eng_sentences)
targ_lang = tf.keras.preprocessing.text.Tokenizer(filters='', oov_token='<unk>')
targ_lang.fit_on_texts(['<start> ' + s + ' <end>' for s in fra_sentences])

# Vectorize, pad, and batch with tf.data
inp_tensor = tf.keras.preprocessing.sequence.pad_sequences(
    inp_lang.texts_to_sequences(eng_sentences), padding='post')
targ_tensor = tf.keras.preprocessing.sequence.pad_sequences(
    targ_lang.texts_to_sequences(
        ['<start> ' + s + ' <end>' for s in fra_sentences]),
    padding='post')
dataset = tf.data.Dataset.from_tensor_slices(
    (inp_tensor, targ_tensor)).batch(BATCH_SIZE, drop_remainder=True)

# Training loop
encoder = Encoder(len(inp_lang.word_index)+1, embedding_dim, units, BATCH_SIZE)
decoder = Decoder(len(targ_lang.word_index)+1, embedding_dim, units, BATCH_SIZE)
for epoch in range(EPOCHS):
    enc_hidden = encoder.initialize_hidden_state()
    total_loss = 0
    for batch, (inp, targ) in enumerate(dataset):
        batch_loss = train_step(inp, targ, enc_hidden, encoder, decoder)
        total_loss += batch_loss
    print(f'Epoch {epoch+1} Loss {total_loss:.4f}')

# Inference
result, _, _ = evaluate("I love machine learning", encoder, decoder,
                        inp_lang, targ_lang, max_length=50)
print(f"Translation: {result}")
This implementation provides a solid foundation for seq2seq models. For production systems, add preprocessing pipelines, proper dataset handling with tf.data, checkpoint saving, and evaluation metrics like BLEU scores. The attention mechanism significantly improves translation quality, especially for longer sentences where fixed context vectors become a bottleneck.
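As a starting point for evaluation, here is a self-contained, unsmoothed sentence-level BLEU sketch (uniform n-gram weights, single reference); for real reporting you would use an established implementation such as sacrebleu, which handles smoothing, tokenization, and corpus-level aggregation:

```python
import math
from collections import Counter

def simple_bleu(candidate, reference, max_n=4):
    """Rough sentence-level BLEU: geometric mean of n-gram precisions
    times a brevity penalty. No smoothing, so any missing n-gram
    order yields a score of 0."""
    cand, ref = candidate.split(), reference.split()
    log_prec = 0.0
    for n in range(1, max_n + 1):
        cand_ngrams = Counter(tuple(cand[i:i+n]) for i in range(len(cand) - n + 1))
        ref_ngrams = Counter(tuple(ref[i:i+n]) for i in range(len(ref) - n + 1))
        # Clipped overlap: each n-gram counts at most as often as in the reference
        overlap = sum((cand_ngrams & ref_ngrams).values())
        if overlap == 0:
            return 0.0
        log_prec += math.log(overlap / sum(cand_ngrams.values())) / max_n
    # Brevity penalty punishes candidates shorter than the reference
    bp = 1.0 if len(cand) > len(ref) else math.exp(1 - len(ref) / len(cand))
    return bp * math.exp(log_prec)

print(simple_bleu("the cat is on the mat", "the cat is on the mat"))  # 1.0
```

Scores range from 0 to 1 (often reported multiplied by 100); a perfect match scores 1.0 and any candidate missing all n-grams of some order scores 0.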