How to Implement GPT in PyTorch
Key Insights
- GPT’s decoder-only architecture relies on masked self-attention to predict the next token autoregressively, making it fundamentally different from encoder-decoder transformers like the original Transformer
- The core mechanism is scaled dot-product attention with causal masking (preventing tokens from attending to future positions), which enables parallel training while preserving autoregressive generation
- A working GPT implementation requires just five components: multi-head attention, feed-forward blocks, positional embeddings, layer normalization, and residual connections
Understanding GPT Architecture
GPT (Generative Pre-trained Transformer) is a decoder-only transformer architecture designed for autoregressive language modeling. Unlike BERT or the original Transformer, GPT uses only the decoder stack with causal (masked) self-attention to predict the next token based solely on previous tokens.
The architecture consists of stacked transformer blocks, each containing:
- Multi-head self-attention with causal masking
- Position-wise feed-forward network
- Layer normalization (applied before each sub-layer in modern variants)
- Residual connections around each sub-layer
We’ll implement a GPT-2-style architecture with pre-normalization, which has become the standard approach due to its better training stability.
Multi-Head Self-Attention
Self-attention is the heart of GPT. It allows each token to attend to all previous tokens in the sequence, computing weighted representations based on relevance.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# Single matrix for Q, K, V projections (more efficient)
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
# Causal mask buffer (registered as buffer, not parameter)
self.register_buffer(
"causal_mask",
torch.tril(torch.ones(1024, 1024)).view(1, 1, 1024, 1024)
)
def forward(self, x):
batch_size, seq_len, d_model = x.shape
# Project to Q, K, V
qkv = self.qkv_proj(x) # (B, T, 3*d_model)
qkv = qkv.reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, n_heads, T, d_k)
q, k, v = qkv[0], qkv[1], qkv[2]
# Scaled dot-product attention
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
# Apply causal mask (prevent attending to future tokens)
scores = scores.masked_fill(
self.causal_mask[:, :, :seq_len, :seq_len] == 0,
float('-inf')
)
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# Apply attention to values
attn_output = torch.matmul(attn_weights, v) # (B, n_heads, T, d_k)
# Concatenate heads
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, seq_len, d_model)
return self.out_proj(attn_output)
The causal mask is critical—it ensures each position can only attend to earlier positions, maintaining the autoregressive property during training. We use masked_fill with -inf so these positions get zero attention weight after softmax.
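To see the masking in action, here’s a minimal standalone sketch (independent of the class above) showing how masked_fill with -inf zeroes out future positions after softmax:

```python
import torch
import torch.nn.functional as F

# Toy example: a 4-token sequence where every raw score is equal
seq_len = 4
scores = torch.zeros(seq_len, seq_len)

# Lower-triangular causal mask: position i may attend to positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(causal_mask == 0, float('-inf'))

weights = F.softmax(scores, dim=-1)
print(weights)
# Row i spreads weight 1/(i+1) uniformly over the first i+1 positions
# and puts exactly zero weight on every future position.
```

Because exp(-inf) = 0, the masked entries vanish from the softmax normalization entirely rather than merely being down-weighted.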
Transformer Block and Positional Encoding
Each transformer block applies attention, then a feed-forward network, with layer normalization and residual connections. Modern GPT uses pre-normalization (LayerNorm before each sub-layer) rather than post-normalization.
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(), # GPT-2 uses GELU activation
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = MultiHeadAttention(d_model, n_heads, dropout)
self.ln2 = nn.LayerNorm(d_model)
self.ff = FeedForward(d_model, d_ff, dropout)
def forward(self, x):
# Pre-normalization with residual connections
x = x + self.attn(self.ln1(x))
x = x + self.ff(self.ln2(x))
return x
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=1024):
super().__init__()
# Learnable positional embeddings (GPT-2 style)
self.pos_embedding = nn.Embedding(max_len, d_model)
    def forward(self, x):
        # x: (batch_size, seq_len) token indices; only the length matters here
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, seq_len)
        return self.pos_embedding(positions)  # (1, seq_len, d_model), broadcasts over batch
GPT-2 uses learnable positional embeddings rather than the sinusoidal encodings from the original Transformer paper. This is simpler and works just as well in practice.
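Here’s a toy-scale sketch of how the positional lookup broadcasts against the token embeddings (the dimensions below are made up for illustration):

```python
import torch
import torch.nn as nn

d_model, max_len = 8, 16
tok_emb_table = nn.Embedding(100, d_model)   # toy vocab of 100 tokens
pos_emb_table = nn.Embedding(max_len, d_model)

idx = torch.randint(0, 100, (2, 5))          # batch of 2, sequence length 5
tok_emb = tok_emb_table(idx)                 # (2, 5, 8)
positions = torch.arange(5).unsqueeze(0)     # (1, 5)
pos_emb = pos_emb_table(positions)           # (1, 5, 8)

x = tok_emb + pos_emb                        # broadcasts over the batch dimension
print(x.shape)  # torch.Size([2, 5, 8])
```

The (1, seq_len, d_model) positional tensor is added to every sequence in the batch, so each position index contributes the same learned vector regardless of which token sits there.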
Complete GPT Model
Now we assemble everything into the full GPT model:
class GPT(nn.Module):
def __init__(self, vocab_size, d_model=768, n_layers=12, n_heads=12,
d_ff=3072, max_len=1024, dropout=0.1):
super().__init__()
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len)
self.transformer_blocks = nn.ModuleList([
TransformerBlock(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
self.ln_f = nn.LayerNorm(d_model) # Final layer norm
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
# Tie weights between token embeddings and output projection
self.lm_head.weight = self.token_embedding.weight
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None):
# idx: (batch_size, seq_len)
tok_emb = self.token_embedding(idx)
pos_emb = self.pos_encoding(idx)
x = tok_emb + pos_emb
for block in self.transformer_blocks:
x = block(x)
x = self.ln_f(x)
logits = self.lm_head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1)
)
return logits, loss
Weight tying between the token embedding and output projection is a common technique that reduces parameters and often improves performance. The initialization scheme follows GPT-2’s approach with small standard deviation.
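To make the savings concrete, here’s a quick sketch that counts parameters before and after tying, using GPT-2 small’s dimensions (vocab_size=50257, d_model=768):

```python
import torch.nn as nn

vocab_size, d_model = 50257, 768  # GPT-2 small dimensions

embedding = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)

untied = sum(p.numel() for p in embedding.parameters()) + \
         sum(p.numel() for p in lm_head.parameters())

lm_head.weight = embedding.weight  # tie: both modules now share one tensor

# Deduplicate shared tensors by identity before counting
unique = {id(p): p for p in [*embedding.parameters(), *lm_head.parameters()]}
tied = sum(p.numel() for p in unique.values())

print(untied - tied)  # 38597376 parameters saved (vocab_size * d_model)
```

For GPT-2 small that is roughly 38.6M parameters, a substantial fraction of the model.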
Training and Text Generation
Here’s a minimal training setup with autoregressive generation:
def train_step(model, optimizer, data, targets):
model.train()
optimizer.zero_grad()
logits, loss = model(data, targets)
loss.backward()
# Gradient clipping for stability
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
return loss.item()
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
model.eval()
for _ in range(max_new_tokens):
        # Crop context to the model's context window (hard-coded to 1024 here;
        # this must match the max_len the model was built with)
        idx_cond = idx if idx.size(1) <= 1024 else idx[:, -1024:]
# Get predictions
logits, _ = model(idx_cond)
logits = logits[:, -1, :] / temperature
# Optional top-k filtering
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
# Sample from distribution
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# Append to sequence
idx = torch.cat([idx, idx_next], dim=1)
return idx
Temperature controls randomness (lower = more deterministic), while top-k sampling restricts choices to the k most likely tokens, improving generation quality.
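Both effects are easy to verify in isolation. The sketch below (with made-up logits) shows temperature sharpening or flattening the softmax, and top-k filtering zeroing out everything outside the k most likely tokens:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])

for temperature in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / temperature, dim=-1)
    print(temperature, probs.max().item())
# Lower temperature sharpens the distribution (the top token's probability
# rises); higher temperature flattens it toward uniform.

top_k = 2
v, _ = torch.topk(logits, top_k)
filtered = logits.masked_fill(logits < v[-1], float('-inf'))
probs = F.softmax(filtered, dim=-1)
print(probs)  # only the 2 largest logits keep nonzero probability
```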
Testing on Real Data
Let’s train on a character-level dataset:
# Simple character-level tokenization
with open('shakespeare.txt', 'r') as f:
    text = f.read()
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
# Prepare data
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
# Initialize model (smaller for demo)
model = GPT(
vocab_size=vocab_size,
d_model=256,
n_layers=6,
n_heads=8,
d_ff=1024,
max_len=256
).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Training loop
block_size = 128
batch_size = 32
for step in range(5000):
# Sample batch
ix = torch.randint(len(train_data) - block_size, (batch_size,))
x = torch.stack([train_data[i:i+block_size] for i in ix]).cuda()
y = torch.stack([train_data[i+1:i+block_size+1] for i in ix]).cuda()
loss = train_step(model, optimizer, x, y)
if step % 500 == 0:
print(f"Step {step}, Loss: {loss:.4f}")
# Generate sample
context = torch.tensor([encode("To be or not to be")], dtype=torch.long).cuda()
generated = generate(model, context, max_new_tokens=200, temperature=0.8, top_k=40)
print(decode(generated[0].tolist()))
After 5000 steps, you should see the model generating Shakespeare-like text with recognizable structure, though a character-level model will make more spelling errors than one trained on subword tokens.
Practical Considerations
This implementation is educational but production systems need additional features: gradient checkpointing for memory efficiency, mixed precision training with automatic mixed precision (AMP), Flash Attention for faster computation, and proper data loading with distributed training support.
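For example, if you’re on PyTorch 2.x, the manual matmul/mask/softmax/matmul sequence inside MultiHeadAttention.forward can be replaced by the fused F.scaled_dot_product_attention kernel, which dispatches to Flash Attention when the hardware supports it. A sketch (not a drop-in patch to the class above):

```python
import torch
import torch.nn.functional as F

# Toy head-split tensors, shaped as in MultiHeadAttention.forward
B, n_heads, T, d_k = 2, 4, 16, 32
q = torch.randn(B, n_heads, T, d_k)
k = torch.randn(B, n_heads, T, d_k)
v = torch.randn(B, n_heads, T, d_k)

# is_causal=True applies the lower-triangular mask internally, so no
# explicit causal_mask buffer is needed
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 16, 32])
```

The fused kernel avoids materializing the full (T, T) attention matrix, which is where both the speed and memory savings come from.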
The architecture scales predictably—GPT-3 uses the same components with d_model=12288, n_layers=96, and n_heads=96. The core implementation remains identical; only hyperparameters change.
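A back-of-the-envelope check of that claim: ignoring embeddings, biases, and layer norms, each block holds roughly 4·d_model² attention parameters (Q, K, V, and output projections) plus 8·d_model² feed-forward parameters (with d_ff = 4·d_model), so the total scales as 12 · n_layers · d_model²:

```python
def approx_params(d_model, n_layers):
    # 4*d^2 for attention projections + 8*d^2 for the feed-forward
    # network (d_ff = 4*d), per layer; embeddings and biases ignored
    return 12 * n_layers * d_model ** 2

print(f"{approx_params(768, 12):,}")    # GPT-2 small: ~85M
print(f"{approx_params(12288, 96):,}")  # GPT-3 scale: ~174B
```

The GPT-3 estimate lands near the published 175B once the embedding matrices are added back in, confirming that scale comes from hyperparameters alone, not new machinery.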
For real applications, use the Hugging Face Transformers library which provides optimized implementations, pretrained weights, and tokenizers. But understanding this from-scratch implementation helps you debug issues, customize architectures, and truly understand what’s happening under the hood.