How to Implement a Transformer in TensorFlow

Key Insights

  • Transformers replace recurrence with self-attention, enabling parallel processing of sequences and capturing long-range dependencies more effectively than RNNs
  • The core innovation is multi-head attention, which allows the model to jointly attend to information from different representation subspaces at different positions
  • Implementing a transformer from scratch requires five key components: positional encoding, scaled dot-product attention, multi-head attention, feed-forward networks, and layer normalization with residual connections

Introduction to Transformer Architecture

The transformer architecture, introduced in “Attention Is All You Need” (Vaswani et al., 2017), fundamentally changed how we approach sequence modeling. Recurrent networks such as LSTMs process a sequence one step at a time. Transformers instead use self-attention to process all positions of a sequence in parallel, which makes training far more efficient on parallel hardware. The original model also set new state-of-the-art results on machine translation.

At its core, the original transformer is an encoder-decoder model. The encoder maps an input sequence to a sequence of continuous representations, and the decoder generates the output sequence one token at a time from those representations. Each encoder layer contains two sublayers: multi-head self-attention and a position-wise feed-forward network. Each decoder layer adds a third sublayer, cross-attention over the encoder output. Every sublayer is wrapped in a residual connection followed by layer normalization.

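The attention mechanism lets the model weigh the importance of every input position when producing each output element. Any two positions are connected by a single attention step rather than a long chain of recurrent steps, so gradients do not have to flow through many time steps. This largely sidesteps the vanishing-gradient problem that makes long-range dependencies hard for RNNs to learn, and lets the model capture dependencies regardless of their distance in the sequence.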
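To make this weighting concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention, softmax(QKᵀ/√d_k)V. It assumes the input itself serves as queries, keys, and values and leaves out the learned projections. The function names are illustrative and do not come from any library. Each row of the weight matrix is a probability distribution over positions, so every output is a weighted average of the value vectors.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """q: (len_q, d_k), k: (len_k, d_k), v: (len_k, d_v)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)      # similarity of each query to each key
    weights = softmax(scores, axis=-1)   # each row sums to 1 over key positions
    return weights @ v, weights          # each output is a weighted average of values

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))              # 5 positions, 8 features each
out, w = scaled_dot_product_attention(x, x, x)
print(out.shape, w.shape)                # (5, 8) (5, 5)
print(np.allclose(w.sum(axis=-1), 1.0))  # True
```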

Positional Encoding Implementation

Transformers have no inherent notion of sequence order since they process all positions simultaneously. Positional encodings inject information about token positions into the input embeddings, allowing the model to leverage sequential structure.
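A quick NumPy check shows why order has to be injected explicitly. It uses an illustrative, projection-free self_attention helper. Without positional information, permuting the input tokens simply permutes the self-attention outputs in the same way, so the layer cannot tell one ordering from another. Adding a position-dependent vector to each embedding breaks this symmetry.

```python
import numpy as np

def self_attention(x):
    # Projection-free self-attention: queries, keys, and values are all x
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ x

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 6))
perm = np.array([2, 0, 3, 1])  # shuffle the token order

# Shuffling the inputs only shuffles the outputs the same way
print(np.allclose(self_attention(x[perm]), self_attention(x)[perm]))  # True
```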

The original paper uses sine and cosine functions of different frequencies. Even dimensions use sine, odd dimensions use cosine, and each sine/cosine pair shares one frequency:

import tensorflow as tf
import numpy as np

def get_positional_encoding(seq_len, d_model):
    """
    Generate sinusoidal positional encodings.
    
    Args:
        seq_len: Maximum sequence length
        d_model: Embedding dimension
    
    Returns:
        Positional encoding matrix of shape (seq_len, d_model)
    """
    position = np.arange(seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    
    pos_encoding = np.zeros((seq_len, d_model))
    pos_encoding[:, 0::2] = np.sin(position * div_term)
    pos_encoding[:, 1::2] = np.cos(position * div_term)
    
    return tf.cast(pos_encoding, dtype=tf.float32)
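For reference, get_positional_encoding computes the paper's formula below. Here pos is the position and i indexes pairs of dimensions. In the code, div_term is the factor 10000^(-2i/d_model) written as an exponential. The slice assignments as written also assume an even d_model: for odd d_model, the cosine slice has one column fewer than div_term.

```latex
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
```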

This encoding has a useful property: for any fixed offset k, PE(pos+k) is a linear function of PE(pos), because each sine/cosine pair is rotated by an angle that depends only on k. The authors hypothesized that this would make it easy for the model to learn to attend by relative position.
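The linear relationship can be checked numerically. The sketch below re-implements the same sinusoidal formula in NumPy and builds the matrix that maps PE(pos) to PE(pos+k). It follows directly from the angle-addition identities: a block-diagonal stack of 2×2 rotations that depends only on k. The helper names positional_encoding and shift_matrix are hypothetical.

```python
import numpy as np

def positional_encoding(seq_len, d_model):
    # Same sinusoidal formula as get_positional_encoding, in plain NumPy
    position = np.arange(seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe, div_term

def shift_matrix(k, div_term):
    # Block-diagonal 2x2 rotations; depends only on the offset k, not on pos
    d_model = 2 * len(div_term)
    m = np.zeros((d_model, d_model))
    for i, w in enumerate(div_term):
        c, s = np.cos(k * w), np.sin(k * w)
        # sin(a + b) = sin(a)cos(b) + cos(a)sin(b)
        m[2 * i, 2 * i], m[2 * i, 2 * i + 1] = c, s
        # cos(a + b) = cos(a)cos(b) - sin(a)sin(b)
        m[2 * i + 1, 2 * i], m[2 * i + 1, 2 * i + 1] = -s, c
    return m

pe, div_term = positional_encoding(seq_len=50, d_model=16)
k = 7
M = shift_matrix(k, div_term)
# PE(pos + k) == M @ PE(pos) for every pos where pos + k is in range
print(np.allclose(pe[k:], pe[:-k] @ M.T))  # True
```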

Building the Multi-Head Attention Layer

Multi-head attention is the heart of the transformer. It performs attention multiple times in parallel with different learned linear projections, allowing the model to jointly attend to information from different representation subspaces.
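The split_heads method in the layer below is only a reshape followed by a transpose. The NumPy sketch here mirrors it and the final concatenation step, using sizes chosen arbitrarily for illustration. It shows that each head works on its own contiguous slice of the projected features, and that concatenation exactly undoes the split.

```python
import numpy as np

batch, seq_len, d_model, num_heads = 2, 5, 16, 4
depth = d_model // num_heads  # each head works in a 4-dimensional subspace

x = np.arange(batch * seq_len * d_model, dtype=np.float32).reshape(batch, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, seq, heads, depth) -> (batch, heads, seq, depth)
heads = x.reshape(batch, seq_len, num_heads, depth).transpose(0, 2, 1, 3)
print(heads.shape)  # (2, 4, 5, 4)

# Head h sees feature slice [h*depth, (h+1)*depth) of every position
print(np.array_equal(heads[:, 1], x[:, :, depth:2 * depth]))  # True

# Concatenate: undo the transpose, then merge the last two axes again
merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
print(np.array_equal(merged, x))  # True
```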

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        
        assert d_model % num_heads == 0
        
        self.depth = d_model // num_heads
        
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        
        self.dense = tf.keras.layers.Dense(d_model)
        
    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
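    # Argument order is (values, keys, queries): the decoder's cross-attention
    # passes the encoder output as v and k and the decoder state as q
    def call(self, v, k, q, mask=None):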
        batch_size = tf.shape(q)[0]
        
        # Linear projections
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        
        # Split heads
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        
        # Scaled dot-product attention
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        
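        # mask holds 1.0 where attention is NOT allowed (padding or future
        # positions); adding -1e9 there drives those softmax weights to ~0
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)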
        
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)
        
        # Concatenate heads
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(output, (batch_size, -1, self.d_model))
        
        # Final linear projection
        output = self.dense(concat_attention)
        
        return output, attention_weights

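The scaling factor 1/√d_k (here d_k = depth = d_model / num_heads) keeps the attention logits in a reasonable range. If the components of q and k are independent with mean 0 and variance 1, their dot product has variance d_k. Without scaling, the logits would grow with the head size and push the softmax into saturated regions where its gradients are extremely small.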
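A short simulation illustrates the variance argument. It assumes query and key components are drawn i.i.d. from a standard normal distribution, which is the setting of the argument rather than a property of trained weights. The raw dot products have variance close to d_k, and dividing by √d_k brings the variance back to about 1 regardless of head size.

```python
import numpy as np

rng = np.random.default_rng(0)
for d_k in (4, 64, 512):
    # Query/key components drawn i.i.d. from N(0, 1)
    q = rng.standard_normal((4000, d_k))
    k = rng.standard_normal((4000, d_k))
    dots = (q * k).sum(axis=1)   # 4000 sample dot products
    scaled = dots / np.sqrt(d_k)
    print(f"d_k={d_k:4d}  var(q.k)={dots.var():7.1f}  var(q.k/sqrt(d_k))={scaled.var():.2f}")
```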

Creating Encoder and Decoder Blocks

Each encoder layer contains a multi-head attention sublayer followed by a position-wise feed-forward network. Both sublayers use residual connections and layer normalization.
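In the original paper, each sublayer's output is LayerNorm(x + Sublayer(x)), a "post-norm" arrangement, and the feed-forward network is FFN(x) = max(0, xW1 + b1)W2 + b2, applied with the same weights at every position. The NumPy sketch below shows both ideas. It omits dropout and LayerNorm's learned gain and bias, and its helper names are hypothetical. Running the FFN block on a single position alone reproduces that position's row exactly.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each position's feature vector (no learned scale/shift here)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def ffn(x, w1, b1, w2, b2):
    # FFN(x) = max(0, x W1 + b1) W2 + b2, the same weights at every position
    return np.maximum(0.0, x @ w1 + b1) @ w2 + b2

rng = np.random.default_rng(0)
seq_len, d_model, dff = 6, 8, 32
x = rng.normal(size=(seq_len, d_model))
w1, b1 = rng.normal(size=(d_model, dff)), np.zeros(dff)
w2, b2 = rng.normal(size=(dff, d_model)), np.zeros(d_model)

# Post-norm residual block: LayerNorm(x + Sublayer(x))
out = layer_norm(x + ffn(x, w1, b1, w2, b2))

# "Position-wise": running the block on position 3 alone gives the same row
row3 = layer_norm(x[3] + ffn(x[3], w1, b1, w2, b2))
print(out.shape, np.allclose(out[3], row3))  # (6, 8) True
```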

class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super(EncoderLayer, self).__init__()
        
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model)
        ])
        
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
    
    def call(self, x, training, mask=None):
        # Multi-head attention
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)
        
        # Feed-forward network
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)
        
        return out2

class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super(DecoderLayer, self).__init__()
        
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model)
        ])
        
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout3 = tf.keras.layers.Dropout(dropout_rate)
    
    def call(self, x, enc_output, training, look_ahead_mask=None, padding_mask=None):
        # Masked multi-head attention (self-attention)
        attn1, _ = self.mha1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)
        
        # Multi-head attention over encoder output
        attn2, _ = self.mha2(enc_output, enc_output, out1, padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1)
        
        # Feed-forward network
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2)
        
        return out3

The decoder has an additional cross-attention layer that attends to the encoder’s output, allowing it to focus on relevant parts of the input sequence when generating each output token.
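The decoder's self-attention needs a look-ahead mask so that each target position only sees earlier positions. Both attention blocks also need a padding mask. The NumPy sketch below shows one common way to build and combine these masks. It uses the same convention as MultiHeadAttention (1.0 means "blocked") and assumes token id 0 is padding, as masked_loss does. The helper names are hypothetical. With uniform logits, each row spreads its weight evenly over the earlier, non-padding positions.

```python
import numpy as np

def look_ahead_mask(size):
    # 1.0 above the diagonal: position i may not attend to any j > i
    return np.triu(np.ones((size, size)), k=1)

def padding_mask(token_ids, pad_id=0):
    # 1.0 where the key position is padding
    return (token_ids == pad_id).astype(np.float64)

def masked_softmax(logits, mask):
    # Same convention as MultiHeadAttention: add a large negative number where mask == 1
    logits = logits + mask * -1e9
    logits = logits - logits.max(axis=-1, keepdims=True)
    w = np.exp(logits)
    return w / w.sum(axis=-1, keepdims=True)

tokens = np.array([5, 9, 3, 0])  # last position is padding
mask = np.maximum(look_ahead_mask(4), padding_mask(tokens)[np.newaxis, :])
weights = masked_softmax(np.zeros((4, 4)), mask)  # uniform logits for clarity
print(np.round(weights, 3))
```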

Assembling the Complete Transformer Model

Now we combine all components into the full transformer model:

class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, 
                 input_vocab_size, target_vocab_size, 
                 max_pe_input, max_pe_target, dropout_rate=0.1):
        super(Transformer, self).__init__()
        
        self.d_model = d_model
        self.num_layers = num_layers
        
        self.embedding_input = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.embedding_target = tf.keras.layers.Embedding(target_vocab_size, d_model)
        
        self.pos_encoding_input = get_positional_encoding(max_pe_input, d_model)
        self.pos_encoding_target = get_positional_encoding(max_pe_target, d_model)
        
        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, dropout_rate) 
                           for _ in range(num_layers)]
        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, dropout_rate) 
                           for _ in range(num_layers)]
        
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)
    
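    def call(self, inputs, training=False):
        inp, tar = inputs
        
        seq_len_inp = tf.shape(inp)[1]
        seq_len_tar = tf.shape(tar)[1]
        
        # Masks use 1.0 for positions that must NOT be attended to.
        # Token id 0 is assumed to be padding (matching masked_loss below).
        enc_padding_mask = tf.cast(tf.math.equal(inp, 0), tf.float32)[:, tf.newaxis, tf.newaxis, :]
        # Look-ahead mask: target position i may not see positions j > i
        look_ahead_mask = 1.0 - tf.linalg.band_part(tf.ones((seq_len_tar, seq_len_tar)), -1, 0)
        dec_padding_mask = tf.cast(tf.math.equal(tar, 0), tf.float32)[:, tf.newaxis, tf.newaxis, :]
        combined_mask = tf.maximum(dec_padding_mask, look_ahead_mask)
        
        # Encoder
        x = self.embedding_input(inp)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding_input[:seq_len_inp, :]
        x = self.dropout(x, training=training)
        
        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training=training, mask=enc_padding_mask)
        
        # Decoder: masked self-attention, then cross-attention over the encoder output
        dec_output = self.embedding_target(tar)
        dec_output *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        dec_output += self.pos_encoding_target[:seq_len_tar, :]
        dec_output = self.dropout(dec_output, training=training)
        
        for i in range(self.num_layers):
            dec_output = self.dec_layers[i](dec_output, x, training=training,
                                            look_ahead_mask=combined_mask,
                                            padding_mask=enc_padding_mask)
        
        # Logits over the target vocabulary: (batch, target_seq_len, target_vocab_size)
        final_output = self.final_layer(dec_output)
        
        return final_output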

Training Example and Practical Tips

Here’s how to set up training with a custom learning rate schedule as described in the original paper:

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps
    
    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

# Model configuration
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
dropout_rate = 0.1

# Initialize model
transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=8500,
    target_vocab_size=8000,
    max_pe_input=1000,
    max_pe_target=1000,
    dropout_rate=dropout_rate
)

# Custom learning rate and optimizer
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# Loss function with masking for padding
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def masked_loss(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_sum(loss_) / tf.reduce_sum(mask)

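def masked_accuracy(real, pred):
    # Plain 'accuracy' would also count padding positions; score real tokens only
    pred = tf.argmax(pred, axis=-1)
    real = tf.cast(real, pred.dtype)
    match = tf.cast(tf.math.equal(real, pred), tf.float32)
    mask = tf.cast(tf.math.not_equal(real, 0), tf.float32)
    return tf.reduce_sum(match * mask) / tf.reduce_sum(mask)

transformer.compile(optimizer=optimizer, loss=masked_loss, metrics=[masked_accuracy])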
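To see what CustomSchedule does, here is a plain-Python re-implementation of the same formula, using a hypothetical helper named transformer_lr. The learning rate rises linearly for warmup_steps steps and then decays in proportion to 1/√step. The peak, at step = warmup_steps, equals (d_model · warmup_steps)^-0.5, which is about 1.40e-3 for d_model=128 and 4000 warmup steps. The sketch starts counting at step 1. CustomSchedule itself returns 0 at step 0, because min(inf, 0) = 0.

```python
def transformer_lr(step, d_model=128, warmup_steps=4000):
    """Plain-Python version of the CustomSchedule formula (hypothetical helper)."""
    # lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    step = max(step, 1)  # start at step 1 to avoid 0 ** -0.5
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in (1, 1000, 4000, 16000, 64000):
    print(f"step {step:>6}: lr = {transformer_lr(step):.2e}")

# The peak sits at step == warmup_steps
print(f"peak: {transformer_lr(4000):.3e}")
```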

Key training tips:

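  1. Use label smoothing to keep the model from becoming overconfident; the original paper used a value of 0.1. tf.keras.losses.SparseCategoricalCrossentropy has no label_smoothing argument, so switch to tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1, reduction='none') and pass one-hot targets, e.g. tf.one_hot(real, target_vocab_size).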

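  2. Mask correctly. Exclude padding tokens from attention and from the loss. Give the decoder's self-attention a look-ahead mask so that position i cannot see later target tokens during training; otherwise the model can simply copy the answer.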

  3. Start with smaller models (4 layers, d_model=128) for debugging, then scale up. The original paper used 6 layers and d_model=512.

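  4. Monitor gradient norms and clip them if necessary, since transformers can be sensitive to exploding gradients early in training. Keras optimizers accept clipnorm or global_clipnorm, e.g. tf.keras.optimizers.Adam(learning_rate, global_clipnorm=1.0).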

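  5. Use mixed precision training (tf.keras.mixed_precision.set_global_policy('mixed_float16')) for significant speedups on GPUs with Tensor Cores. Keep the final Dense layer in float32 (dtype='float32') so the softmax stays numerically stable. Note also that the attention code above casts dk to tf.float32 and adds mask * -1e9. Under a float16 policy, both must be cast to the logits' dtype, and -1e9 exceeds the float16 range (about ±65504), so a smaller constant is needed.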
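To illustrate tip 1, here is a NumPy sketch of cross-entropy against label-smoothed targets. It assumes the smoothing rule documented for tf.keras.losses.CategoricalCrossentropy, y · (1 − ε) + ε / num_classes. The function name smoothed_cross_entropy is hypothetical. Without smoothing, a near-certain correct prediction has almost zero loss. With ε = 0.1, that same overconfident prediction is penalized more than a moderately confident one.

```python
import numpy as np

def smoothed_cross_entropy(logits, labels, epsilon=0.1):
    """Cross-entropy against label-smoothed one-hot targets (hypothetical helper)."""
    vocab = logits.shape[-1]
    one_hot = np.eye(vocab)[labels]
    # Keras-style smoothing: y * (1 - eps) + eps / num_classes
    targets = one_hot * (1.0 - epsilon) + epsilon / vocab
    # Numerically stable log-softmax
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -(targets * log_probs).sum(axis=-1)

labels = np.array([2])
confident = np.array([[0.0, 0.0, 20.0, 0.0]])  # nearly all mass on the correct class
moderate = np.array([[0.0, 0.0, 3.0, 0.0]])

for name, logits in (("confident", confident), ("moderate", moderate)):
    plain = smoothed_cross_entropy(logits, labels, epsilon=0.0)
    smooth = smoothed_cross_entropy(logits, labels, epsilon=0.1)
    print(name, plain, smooth)
```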

The transformer architecture has become the foundation for modern NLP models like BERT, GPT, and T5. Understanding its implementation provides insight into these more complex systems and enables you to customize architectures for specific tasks.
