How to Implement a Transformer in TensorFlow
Key Insights
- Transformers replace recurrence with self-attention, enabling parallel processing of sequences and capturing long-range dependencies more effectively than RNNs
- The core innovation is multi-head attention, which allows the model to jointly attend to information from different representation subspaces at different positions
- Implementing a transformer from scratch requires five key components: positional encoding, scaled dot-product attention, multi-head attention, feed-forward networks, and layer normalization with residual connections
Introduction to Transformer Architecture
The transformer architecture, introduced in “Attention is All You Need,” fundamentally changed how we approach sequence modeling. Unlike RNNs and LSTMs that process sequences sequentially, transformers use self-attention mechanisms to process entire sequences in parallel, dramatically improving training efficiency and model performance.
At its core, a transformer consists of an encoder-decoder structure. The encoder maps input sequences to continuous representations, while the decoder generates output sequences from these representations. Each encoder and decoder layer contains two main sublayers: multi-head self-attention and position-wise feed-forward networks, with residual connections and layer normalization around each sublayer.
The attention mechanism allows the model to weigh the importance of different positions in the input when producing each output element. This eliminates the vanishing gradient problem that plagued RNNs and enables the model to capture dependencies regardless of their distance in the sequence.
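Before diving into the full implementation, the core computation can be sketched in a few lines of NumPy (toy random vectors, no learned projections):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 3, 4
q = rng.standard_normal((seq_len, d))
k = rng.standard_normal((seq_len, d))
v = rng.standard_normal((seq_len, d))

# Scaled dot-product attention: each output row is a weighted average
# of the value rows, with weights from a softmax over scaled dot products.
logits = q @ k.T / np.sqrt(d)
weights = np.exp(logits)
weights /= weights.sum(axis=-1, keepdims=True)
output = weights @ v

print(weights.sum(axis=-1))  # each row sums to 1
print(output.shape)          # (3, 4)
```

Every output position mixes information from every input position in a single step, which is why distance in the sequence no longer matters.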
Positional Encoding Implementation
Transformers have no inherent notion of sequence order since they process all positions simultaneously. Positional encodings inject information about token positions into the input embeddings, allowing the model to leverage sequential structure.
The original paper uses sinusoidal functions with different frequencies for each dimension:
```python
import tensorflow as tf
import numpy as np

def get_positional_encoding(seq_len, d_model):
    """
    Generate sinusoidal positional encodings.

    Args:
        seq_len: Maximum sequence length
        d_model: Embedding dimension

    Returns:
        Positional encoding matrix of shape (seq_len, d_model)
    """
    position = np.arange(seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pos_encoding = np.zeros((seq_len, d_model))
    pos_encoding[:, 0::2] = np.sin(position * div_term)
    pos_encoding[:, 1::2] = np.cos(position * div_term)
    return tf.cast(pos_encoding, dtype=tf.float32)
```
This encoding has a useful property: for any fixed offset k, PE(pos+k) can be represented as a linear function of PE(pos), allowing the model to easily learn to attend by relative positions.
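This property follows from the angle-addition identities, and can be verified numerically for a single sin/cos dimension pair (a pure-NumPy check, assuming d_model = 128 and the dimension pair at index 4):

```python
import numpy as np

# One frequency from the encoding: dimension pair (4, 5) with d_model = 128
w = np.exp(4 * -(np.log(10000.0) / 128))
pos, k = 17.0, 5.0

pe_pos = np.array([np.sin(w * pos), np.cos(w * pos)])
pe_shifted = np.array([np.sin(w * (pos + k)), np.cos(w * (pos + k))])

# A rotation matrix depending only on the offset k maps PE(pos) to PE(pos + k)
rotation = np.array([[np.cos(w * k), np.sin(w * k)],
                     [-np.sin(w * k), np.cos(w * k)]])

print(np.allclose(rotation @ pe_pos, pe_shifted))  # True
```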
Building the Multi-Head Attention Layer
Multi-head attention is the heart of the transformer. It performs attention multiple times in parallel with different learned linear projections, allowing the model to jointly attend to information from different representation subspaces.
```python
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % num_heads == 0
        self.depth = d_model // num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        batch_size = tf.shape(q)[0]

        # Linear projections
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        # Split heads
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # Scaled dot-product attention
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        if mask is not None:
            scaled_attention_logits += (mask * -1e9)

        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)

        # Concatenate heads
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(output, (batch_size, -1, self.d_model))

        # Final linear projection
        output = self.dense(concat_attention)
        return output, attention_weights
```
The scaling factor (1/√dk) prevents the dot products from growing too large, which would push the softmax function into regions with extremely small gradients.
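To see why, compare the spread of raw and scaled dot products for random vectors (an illustrative NumPy check; without the scaling, the softmax over the logits is typically nearly one-hot):

```python
import numpy as np

rng = np.random.default_rng(0)
dk = 512
q = rng.standard_normal(dk)
keys = rng.standard_normal((1000, dk))

logits = keys @ q              # raw dot products: std grows like sqrt(dk)
scaled = logits / np.sqrt(dk)  # rescaled to roughly unit variance

softmax = lambda x: np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(logits.std(), scaled.std())  # roughly sqrt(512) ~ 22.6 vs roughly 1
# Unscaled softmax concentrates almost all mass on one position:
print(softmax(logits[:8]).max(), softmax(scaled[:8]).max())
```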
Creating Encoder and Decoder Blocks
Each encoder layer contains a multi-head attention sublayer followed by a position-wise feed-forward network. Both sublayers use residual connections and layer normalization.
```python
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model)
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, training, mask=None):
        # Multi-head attention
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        # Feed-forward network
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)
        return out2


class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model)
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout3 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, enc_output, training, look_ahead_mask=None, padding_mask=None):
        # Masked multi-head attention (self-attention)
        attn1, _ = self.mha1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)

        # Multi-head attention over encoder output
        # (v, k from the encoder; q from the decoder)
        attn2, _ = self.mha2(enc_output, enc_output, out1, padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1)

        # Feed-forward network
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2)
        return out3
```
The decoder has an additional cross-attention layer that attends to the encoder’s output, allowing it to focus on relevant parts of the input sequence when generating each output token.
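The attention code above adds `mask * -1e9` to the logits, so masks use the convention that 1 marks a blocked position. A minimal sketch of the two standard masks in that convention (following the approach of the official TensorFlow transformer tutorial):

```python
import tensorflow as tf

def create_padding_mask(seq):
    """1.0 wherever the token id is 0 (padding); shaped to broadcast
    against attention logits of shape (batch, num_heads, seq_q, seq_k)."""
    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    """Strictly upper-triangular ones: position i may not attend to j > i."""
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

print(create_look_ahead_mask(3).numpy())
# [[0. 1. 1.]
#  [0. 0. 1.]
#  [0. 0. 0.]]
```

Passing the look-ahead mask to the decoder's first attention sublayer is what prevents it from peeking at future target tokens during training.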
Assembling the Complete Transformer Model
Now we combine all components into the full transformer model:
```python
class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff,
                 input_vocab_size, target_vocab_size,
                 max_pe_input, max_pe_target, dropout_rate=0.1):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding_input = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.embedding_target = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding_input = get_positional_encoding(max_pe_input, d_model)
        self.pos_encoding_target = get_positional_encoding(max_pe_target, d_model)

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, dropout_rate)
                           for _ in range(num_layers)]
        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, dropout_rate)
                           for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inputs, training):
        inp, tar = inputs
        seq_len_inp = tf.shape(inp)[1]
        seq_len_tar = tf.shape(tar)[1]

        # Encoder (padding masks omitted here for brevity)
        x = self.embedding_input(inp)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding_input[:seq_len_inp, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training)

        # Decoder: the look-ahead mask blocks attention to future positions,
        # so position i cannot see target tokens at positions > i
        look_ahead_mask = 1 - tf.linalg.band_part(
            tf.ones((seq_len_tar, seq_len_tar)), -1, 0)
        dec_output = self.embedding_target(tar)
        dec_output *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        dec_output += self.pos_encoding_target[:seq_len_tar, :]
        dec_output = self.dropout(dec_output, training=training)
        for i in range(self.num_layers):
            dec_output = self.dec_layers[i](dec_output, x, training,
                                            look_ahead_mask=look_ahead_mask)

        final_output = self.final_layer(dec_output)
        return final_output
```
Training Example and Practical Tips
Here’s how to set up training with a custom learning rate schedule as described in the original paper:
```python
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
```
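The shape of this schedule is easy to sanity-check with a plain-Python version of the same formula (`lr_at` is a hypothetical helper, not part of the model code):

```python
def lr_at(step, d_model=128, warmup_steps=4000):
    # d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The rate rises linearly during warmup, peaks at warmup_steps,
# then decays proportionally to 1/sqrt(step)
print(lr_at(1) < lr_at(4000) > lr_at(40000))  # True
```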
```python
# Model configuration
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
dropout_rate = 0.1

# Initialize model
transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=8500,
    target_vocab_size=8000,
    max_pe_input=1000,
    max_pe_target=1000,
    dropout_rate=dropout_rate
)

# Custom learning rate and optimizer
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)

# Loss function with masking for padding
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def masked_loss(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_sum(loss_) / tf.reduce_sum(mask)

transformer.compile(optimizer=optimizer, loss=masked_loss, metrics=['accuracy'])
```
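The masking arithmetic in `masked_loss` can be checked in isolation with toy numbers (pure NumPy; the per-token loss values here are made up for illustration):

```python
import numpy as np

# Toy batch: 2 sequences of length 3, where token id 0 is padding
real = np.array([[5, 3, 0],
                 [2, 0, 0]])
per_token_loss = np.array([[1.0, 2.0, 9.0],
                           [4.0, 9.0, 9.0]])  # values at pad slots are junk

mask = (real != 0).astype(float)
masked = per_token_loss * mask

# Average only over real (non-padding) tokens
print(masked.sum() / mask.sum())  # (1 + 2 + 4) / 3 ≈ 2.33
```

Without the mask, the junk losses at padding positions would dominate the average and distort the gradients.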
Key training tips:
- Use label smoothing to prevent the model from becoming overconfident: set `label_smoothing=0.1` in your loss function. Note that `label_smoothing` is an argument of `tf.keras.losses.CategoricalCrossentropy`, so it requires one-hot targets.
- Implement proper masking for padding tokens and future positions in the decoder to prevent information leakage.
- Start with smaller models (4 layers, d_model=128) for debugging, then scale up. The original paper used 6 layers and d_model=512.
- Monitor gradient norms and clip if necessary; transformers can be sensitive to exploding gradients early in training.
- Use mixed precision training (`tf.keras.mixed_precision`) for significant speedups on modern GPUs.
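For the gradient-clipping tip, Keras optimizers accept `clipnorm` (per-variable) or `global_clipnorm` arguments, so no custom training loop is needed; for example:

```python
import tensorflow as tf

# Clip each gradient to a maximum L2 norm of 1.0 before applying updates
optimizer = tf.keras.optimizers.Adam(
    learning_rate=1e-4, beta_1=0.9, beta_2=0.98, epsilon=1e-9, clipnorm=1.0)
```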
The transformer architecture has become the foundation for modern NLP models like BERT, GPT, and T5. Understanding its implementation provides insight into these more complex systems and enables you to customize architectures for specific tasks.