How to Implement GPT in PyTorch


Key Insights

  • GPT’s decoder-only architecture relies on masked self-attention to predict the next token autoregressively, making it fundamentally different from encoder-decoder transformers like the original Transformer
  • The core innovation is scaled dot-product attention with causal masking—preventing tokens from attending to future positions—which enables parallel training while maintaining autoregressive generation
  • A working GPT implementation requires just five components: multi-head attention, feed-forward blocks, positional embeddings, layer normalization, and residual connections

Understanding GPT Architecture

GPT (Generative Pre-trained Transformer) is a decoder-only transformer architecture designed for autoregressive language modeling. Unlike BERT or the original Transformer, GPT uses only the decoder stack with causal (masked) self-attention to predict the next token based solely on previous tokens.

The architecture consists of stacked transformer blocks, each containing:

  • Multi-head self-attention with causal masking
  • Position-wise feed-forward network
  • Layer normalization (applied before each sub-layer in modern variants)
  • Residual connections around each sub-layer

We’ll implement a GPT-2-style architecture with pre-normalization, which has become the standard approach because it trains more stably.
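The two orderings differ only in where LayerNorm sits relative to the residual path. A minimal sketch, with a single linear layer standing in for the attention or feed-forward sub-layer:

```python
import torch
import torch.nn as nn

d_model = 8
ln = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or FFN
x = torch.randn(2, 4, d_model)

# Post-norm (original Transformer): normalize after the residual sum
post = ln(x + sublayer(x))

# Pre-norm (GPT-2): normalize the sub-layer input, then add the residual
pre = x + sublayer(ln(x))

print(post.shape, pre.shape)  # both keep the (batch, seq, d_model) shape
```

With pre-norm, the residual stream is never normalized away, which keeps gradients well-scaled in deep stacks.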

Multi-Head Self-Attention

Self-attention is the heart of GPT. It allows each token to attend to all previous tokens in the sequence, computing weighted representations based on relevance.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Single matrix for Q, K, V projections (more efficient)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Causal mask (a buffer moves with the model's device/dtype but is
        # not trainable); 1024 here caps the maximum supported context length
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(1024, 1024)).view(1, 1, 1024, 1024)
        )
    
    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        
        # Project to Q, K, V
        qkv = self.qkv_proj(x)  # (B, T, 3*d_model)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, T, d_k)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply causal mask (prevent attending to future tokens)
        scores = scores.masked_fill(
            self.causal_mask[:, :, :seq_len, :seq_len] == 0,
            float('-inf')
        )
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)  # (B, n_heads, T, d_k)
        
        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, d_model)
        
        return self.out_proj(attn_output)

The causal mask is critical—it ensures each position can only attend to earlier positions, maintaining the autoregressive property during training. We use masked_fill with -inf so these positions get zero attention weight after softmax.
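The effect is easy to see in a tiny standalone example (sequence length 4, no batching or heads):

```python
import torch
import torch.nn.functional as F

T = 4
scores = torch.randn(T, T)            # raw attention scores
mask = torch.tril(torch.ones(T, T))   # lower-triangular causal mask

# Future positions get -inf, so softmax assigns them exactly zero weight
masked = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(masked, dim=-1)

print(weights)
# Row i has nonzero weight only on columns 0..i, and each row sums to 1
```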

Transformer Block and Positional Encoding

Each transformer block applies attention, then a feed-forward network, with layer normalization and residual connections. Modern GPT uses pre-normalization (LayerNorm before each sub-layer) rather than post-normalization.

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),  # GPT-2 uses GELU activation
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
    
    def forward(self, x):
        # Pre-normalization with residual connections
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1024):
        super().__init__()
        # Learnable positional embeddings (GPT-2 style)
        self.pos_embedding = nn.Embedding(max_len, d_model)
    
    def forward(self, idx):
        # idx: (batch_size, seq_len) token indices; returns (1, seq_len, d_model),
        # which broadcasts across the batch when added to token embeddings
        seq_len = idx.size(1)
        positions = torch.arange(seq_len, device=idx.device).unsqueeze(0)
        return self.pos_embedding(positions)

GPT-2 uses learnable positional embeddings rather than the sinusoidal encodings from the original Transformer paper. This is simpler and works just as well in practice.
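For comparison, the original paper's fixed sinusoidal encoding can be computed directly; it needs no trained parameters:

```python
import math
import torch

def sinusoidal_encoding(max_len, d_model):
    # PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d))
    pos = torch.arange(max_len).unsqueeze(1).float()
    div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe

pe = sinusoidal_encoding(256, 64)
print(pe.shape)  # torch.Size([256, 64])
```

The learnable table simply lets the model discover position features itself instead of fixing them to these sinusoids.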

Complete GPT Model

Now we assemble everything into the full GPT model:

class GPT(nn.Module):
    def __init__(self, vocab_size, d_model=768, n_layers=12, n_heads=12, 
                 d_ff=3072, max_len=1024, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        
        self.ln_f = nn.LayerNorm(d_model)  # Final layer norm
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
        # Tie weights between token embeddings and output projection
        self.lm_head.weight = self.token_embedding.weight
        
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, idx, targets=None):
        # idx: (batch_size, seq_len)
        tok_emb = self.token_embedding(idx)
        pos_emb = self.pos_encoding(idx)
        x = tok_emb + pos_emb
        
        for block in self.transformer_blocks:
            x = block(x)
        
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        
        return logits, loss

Weight tying between the token embedding and output projection is a common technique that reduces parameters and often improves performance. The initialization scheme follows GPT-2’s approach with small standard deviation.
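The saving is easy to verify in isolation: tying makes both modules share one tensor, halving the count of these two large matrices.

```python
import torch.nn as nn

vocab_size, d_model = 1000, 64
emb = nn.Embedding(vocab_size, d_model)
head = nn.Linear(d_model, vocab_size, bias=False)

untied = emb.weight.numel() + head.weight.numel()  # 2 * vocab_size * d_model

head.weight = emb.weight  # tie: both modules now reference one tensor

# Deduplicate by identity to count the shared parameter only once
params = {id(p): p for p in list(emb.parameters()) + list(head.parameters())}
tied = sum(p.numel() for p in params.values())

print(untied, tied)  # 128000 64000
```

With GPT-2's 50257-token vocabulary and d_model=768, this saves roughly 38M parameters.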

Training and Text Generation

Here’s a minimal training setup with autoregressive generation:

def train_step(model, optimizer, data, targets):
    model.train()
    optimizer.zero_grad()
    
    logits, loss = model(data, targets)
    loss.backward()
    
    # Gradient clipping for stability
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    return loss.item()

@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None, max_len=1024):
    model.eval()
    
    for _ in range(max_new_tokens):
        # Crop context to the model's maximum sequence length
        idx_cond = idx if idx.size(1) <= max_len else idx[:, -max_len:]
        
        # Get predictions
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        
        # Optional top-k filtering
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')
        
        # Sample from distribution
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        
        # Append to sequence
        idx = torch.cat([idx, idx_next], dim=1)
    
    return idx

Temperature controls randomness (lower = more deterministic), while top-k sampling restricts choices to the k most likely tokens, improving generation quality.
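Both knobs can be demonstrated on a hand-built logits vector, independent of any model:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])

# Lower temperature sharpens the distribution toward the argmax
sharp = F.softmax(logits / 0.5, dim=-1)
flat = F.softmax(logits / 2.0, dim=-1)
print(sharp[0, 0] > flat[0, 0])  # low temperature puts more mass on the top token

# Top-k=2 filtering: everything below the 2nd-largest logit becomes -inf
k = 2
v, _ = torch.topk(logits, k)
filtered = logits.masked_fill(logits < v[:, [-1]], float('-inf'))
probs = F.softmax(filtered, dim=-1)
print(probs)  # only the top-2 tokens keep nonzero probability
```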

Testing on Real Data

Let’s train on a character-level dataset:

# Simple character-level tokenization
with open('shakespeare.txt', 'r') as f:
    text = f.read()
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Prepare data
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# Initialize model (smaller for demo)
model = GPT(
    vocab_size=vocab_size,
    d_model=256,
    n_layers=6,
    n_heads=8,
    d_ff=1024,
    max_len=256
).cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Training loop
block_size = 128
batch_size = 32

for step in range(5000):
    # Sample batch
    ix = torch.randint(len(train_data) - block_size, (batch_size,))
    x = torch.stack([train_data[i:i+block_size] for i in ix]).cuda()
    y = torch.stack([train_data[i+1:i+block_size+1] for i in ix]).cuda()
    
    loss = train_step(model, optimizer, x, y)
    
    if step % 500 == 0:
        print(f"Step {step}, Loss: {loss:.4f}")

# Generate sample
context = torch.tensor([encode("To be or not to be")], dtype=torch.long).cuda()
generated = generate(model, context, max_new_tokens=200, temperature=0.8, top_k=40)
print(decode(generated[0].tolist()))

After 5000 steps, the model should produce Shakespeare-like text with recognizable structure, though character-level models make more spelling errors than models trained on subword tokens.

Practical Considerations

This implementation is educational but production systems need additional features: gradient checkpointing for memory efficiency, mixed precision training with automatic mixed precision (AMP), Flash Attention for faster computation, and proper data loading with distributed training support.

The architecture scales predictably—GPT-3 uses the same components with d_model=12288, n_layers=96, and n_heads=96. The core implementation remains identical; only hyperparameters change.
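A rough back-of-the-envelope check: ignoring biases, LayerNorm, and embeddings, each block has about 4·d² attention weights (Q, K, V, output projections) plus 8·d² feed-forward weights (with d_ff = 4·d, as used throughout this article):

```python
def approx_params(d_model, n_layers):
    # ~4*d^2 for attention projections + ~8*d^2 for the FFN, per block
    return 12 * d_model**2 * n_layers

gpt2_small = approx_params(768, 12)   # ~85M of GPT-2 small's 124M total
gpt3 = approx_params(12288, 96)       # ~174B of GPT-3's 175B total
print(f"{gpt2_small/1e6:.0f}M, {gpt3/1e9:.0f}B")  # 85M, 174B
```

The remainder in each case is mostly the token embedding table, which the estimate deliberately omits.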

For real applications, use the Hugging Face Transformers library which provides optimized implementations, pretrained weights, and tokenizers. But understanding this from-scratch implementation helps you debug issues, customize architectures, and truly understand what’s happening under the hood.
