How to Implement BERT in PyTorch

Key Insights

  • BERT’s bidirectional architecture processes text in both directions simultaneously, enabling richer context understanding than previous unidirectional models—this makes it ideal for tasks requiring deep semantic comprehension like question answering and named entity recognition.
  • You don’t need to build BERT from scratch for most applications; Hugging Face’s transformers library provides battle-tested implementations that you can fine-tune in under 100 lines of code, saving weeks of development time.
  • Fine-tuning BERT requires careful memory management—a base BERT model with batch size 16 consumes roughly 10GB of GPU memory, making gradient accumulation and mixed precision training essential for resource-constrained environments.

Introduction to BERT Architecture

BERT (Bidirectional Encoder Representations from Transformers) fundamentally changed how we approach NLP tasks. Unlike GPT’s left-to-right architecture or ELMo’s shallow bidirectionality, BERT reads text in both directions simultaneously through its masked language modeling objective. This bidirectional context allows BERT to understand nuanced relationships between words that unidirectional models miss.

The architecture consists of stacked transformer encoder layers. Each layer contains multi-head self-attention and a position-wise feed-forward network. BERT-base has 12 layers with 768 hidden dimensions and 12 attention heads, totaling roughly 110M parameters. BERT-large scales this up to 24 layers, 1,024 hidden dimensions, and 16 attention heads, for roughly 340M parameters.
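As a sanity check on those figures, the BERT-base parameter count can be reproduced from the published hyperparameters (30,522-token WordPiece vocabulary, 512 max positions, 3,072-dimensional feed-forward layers):

```python
vocab, hidden, layers, ffn, max_pos = 30522, 768, 12, 3072, 512

embed = vocab * hidden + max_pos * hidden + 2 * hidden + 2 * hidden  # word + position + segment + LayerNorm
attn = 4 * (hidden * hidden + hidden)                    # Q, K, V, and output projections (weights + biases)
ffn_params = hidden * ffn + ffn + ffn * hidden + hidden  # up- and down-projections
norms = 2 * (2 * hidden)                                 # two LayerNorms per encoder block
pooler = hidden * hidden + hidden

total = embed + layers * (attn + ffn_params + norms) + pooler
print(f"{total / 1e6:.1f}M parameters")  # ≈ 109.5M
```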

BERT’s pre-training uses two objectives: Masked Language Modeling (MLM) masks 15% of input tokens and trains the model to predict them, forcing it to learn bidirectional representations. Next Sentence Prediction (NSP) trains BERT to understand sentence relationships, though recent research suggests NSP’s contribution is debatable.
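The masking procedure follows an 80/10/10 rule: of the 15% of positions selected for prediction, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged. A minimal sketch of that rule (the helper name and ID values here are illustrative; in practice Hugging Face's DataCollatorForLanguageModeling handles this):

```python
import torch

def mask_tokens(input_ids, mask_token_id, vocab_size, mlm_prob=0.15):
    """BERT-style MLM masking: select ~15% of positions; of those,
    80% -> [MASK], 10% -> random token, 10% -> unchanged."""
    labels = input_ids.clone()
    masked_indices = torch.bernoulli(torch.full(input_ids.shape, mlm_prob)).bool()
    labels[~masked_indices] = -100  # positions CrossEntropyLoss should ignore

    masked_input = input_ids.clone()
    # 80% of the selected positions become [MASK]
    replace = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
    masked_input[replace] = mask_token_id
    # Half of the remainder (10% overall) become a random vocabulary token
    randomize = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~replace
    masked_input[randomize] = torch.randint(vocab_size, input_ids.shape)[randomize]
    # The remaining 10% stay unchanged
    return masked_input, labels
```

During pre-training, cross-entropy loss is computed only at positions whose label is not -100.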

Setting Up the Environment

Install the required dependencies. We’ll use PyTorch with Hugging Face’s transformers library, which provides pre-trained BERT models and utilities.

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install transformers==4.35.0
pip install datasets==2.14.0
pip install scikit-learn

Import the necessary modules and verify your setup:

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, BertConfig
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW  # transformers.AdamW is deprecated; use PyTorch's implementation
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

Loading Pre-trained BERT Models

Hugging Face makes loading pre-trained BERT models trivial. The library handles downloading, caching, and initialization automatically.

# Load pre-trained BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# Tokenize sample text
text = "BERT revolutionized natural language processing."
encoded = tokenizer(
    text,
    padding='max_length',
    max_length=128,
    truncation=True,
    return_tensors='pt'
)

print(f"Input IDs shape: {encoded['input_ids'].shape}")
print(f"Attention mask shape: {encoded['attention_mask'].shape}")
print(f"Tokens: {tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])}")

# Get model outputs
with torch.no_grad():
    outputs = model(**encoded)
    last_hidden_state = outputs.last_hidden_state
    pooled_output = outputs.pooler_output

print(f"Last hidden state shape: {last_hidden_state.shape}")  # [batch_size, seq_len, hidden_size]
print(f"Pooled output shape: {pooled_output.shape}")  # [batch_size, hidden_size]

The tokenizer converts text to input IDs, adds special tokens ([CLS], [SEP]), and creates attention masks. The model returns two key outputs: last_hidden_state contains embeddings for all tokens, while pooler_output provides a fixed-size representation suitable for classification tasks.

Building BERT from Scratch

Understanding BERT’s internals helps debug issues and customize architectures. Here’s a simplified implementation of core components:

class BertEmbeddings(nn.Module):
    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(2, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        
        word_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        token_type_embeds = self.token_type_embeddings(token_type_ids)
        
        embeddings = word_embeds + position_embeds + token_type_embeds
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(hidden_size, hidden_size)
    
    def forward(self, hidden_states, attention_mask=None):
        batch_size = hidden_states.size(0)
        
        # Linear projections and reshape for multi-head attention
        Q = self.query(hidden_states).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(hidden_states).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(hidden_states).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        if attention_mask is not None:
            # Expand the [batch, seq_len] padding mask to [batch, 1, 1, seq_len]
            # so it broadcasts across heads and query positions
            scores = scores.masked_fill(attention_mask[:, None, None, :] == 0, -1e9)
        
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        context = torch.matmul(attention_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        output = self.output(context)
        return output
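The head split/merge reshaping above is the easiest part to get wrong. This standalone snippet traces the tensor shapes through one attention step with BERT-base dimensions (it reuses the same tensor for queries, keys, and values purely to keep the shape bookkeeping visible):

```python
import torch

batch, seq, hidden, heads = 2, 16, 768, 12
head_dim = hidden // heads  # 64

x = torch.randn(batch, seq, hidden)
# Split the hidden dimension into heads: [batch, heads, seq, head_dim]
q = x.view(batch, seq, heads, head_dim).transpose(1, 2)
scores = q @ q.transpose(-2, -1) / head_dim ** 0.5   # [batch, heads, seq, seq]
weights = torch.softmax(scores, dim=-1)              # each row sums to 1
context = weights @ q                                # [batch, heads, seq, head_dim]
# Merge the heads back: [batch, seq, hidden]
merged = context.transpose(1, 2).contiguous().view(batch, seq, hidden)
print(merged.shape)  # torch.Size([2, 16, 768])
```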

Fine-tuning BERT for Classification

Fine-tuning adapts pre-trained BERT to specific tasks. Here’s a complete sentiment analysis implementation:

class BertForSentimentClassification(nn.Module):
    def __init__(self, num_classes=2, dropout=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


# Training loop
def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    
    for batch in dataloader:
        optimizer.zero_grad()
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        logits = model(input_ids, attention_mask)
        loss = nn.CrossEntropyLoss()(logits, labels)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)


# Initialize model and training components
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForSentimentClassification(num_classes=2).to(device)

optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
epochs = 3
total_steps = len(train_dataloader) * epochs  # train_dataloader: your DataLoader of tokenized batches
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),  # warm up over ~10% of training steps
    num_training_steps=total_steps
)

Use a small learning rate (2e-5 to 5e-5) for fine-tuning. BERT’s pre-trained weights are already optimized, so aggressive learning rates destroy learned representations.
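Beyond the learning rate, the original BERT code also excludes biases and LayerNorm weights from weight decay. A sketch of the parameter grouping, shown on a tiny stand-in module (with a real model you would iterate model.named_parameters() the same way; the attribute names mimic Hugging Face's):

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Stand-in for BERT; 'LayerNorm' mirrors Hugging Face's parameter naming."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(768, 768)
        self.LayerNorm = nn.LayerNorm(768)

model = TinyEncoder()
no_decay = ('bias', 'LayerNorm.weight')
grouped_params = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(grouped_params, lr=2e-5, eps=1e-8)
```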

Inference and Model Deployment

After training, implement efficient inference and model persistence:

def predict(model, text, tokenizer, device, max_length=128):
    model.eval()
    
    encoded = tokenizer(
        text,
        padding='max_length',
        max_length=max_length,
        truncation=True,
        return_tensors='pt'
    )
    
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probabilities = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1)
    
    return predicted_class.item(), probabilities.cpu().numpy()


# Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epochs,  # or the current epoch index when saving inside a training loop
}, 'bert_sentiment_model.pt')

# Load model
checkpoint = torch.load('bert_sentiment_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Batch inference
def batch_predict(model, texts, tokenizer, device, batch_size=32):
    model.eval()
    predictions = []
    
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        encoded = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors='pt'
        )
        
        with torch.no_grad():
            logits = model(
                encoded['input_ids'].to(device),
                encoded['attention_mask'].to(device)
            )
            batch_predictions = torch.argmax(logits, dim=1)
            predictions.extend(batch_predictions.cpu().numpy())
    
    return predictions
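One further throughput optimization: because the tokenizer pads each batch to its longest sequence, sorting inputs by approximate length before batching keeps padding waste low. The helper names below (sort_for_batching, restore_order) are illustrative, not library functions:

```python
def sort_for_batching(texts):
    """Sort texts by approximate length so each batch pads to a similar size;
    also return the permutation needed to restore the original order."""
    order = sorted(range(len(texts)), key=lambda i: len(texts[i].split()))
    return [texts[i] for i in order], order

def restore_order(predictions, order):
    """Map predictions made on sorted texts back to the input order."""
    restored = [None] * len(order)
    for rank, original_idx in enumerate(order):
        restored[original_idx] = predictions[rank]
    return restored

texts = ["a much longer input sentence here", "short", "medium length text"]
sorted_texts, order = sort_for_batching(texts)
preds = list(range(len(sorted_texts)))      # stand-in for batch_predict output
print(restore_order(preds, order))          # [2, 0, 1]
```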

Performance Optimization Tips

BERT’s memory footprint requires optimization for production environments:

# Mixed precision training with automatic mixed precision
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    
    with autocast():
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# Gradient accumulation for larger effective batch sizes
accumulation_steps = 4

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    
    logits = model(input_ids, attention_mask)
    loss = criterion(logits, labels) / accumulation_steps
    loss.backward()
    
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Mixed precision training reduces memory usage by roughly 40-50% and can speed up training by 2-3x on GPUs with Tensor Cores. Gradient accumulation simulates larger batch sizes without increasing memory consumption, which is critical when GPU memory limits the per-step batch size.
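The two techniques combine naturally; the ordering that matters is scale, backward, unscale and clip, step, update. A self-contained sketch with a stand-in linear model and fake batches (swap in the BERT classifier and a real dataloader; on CPU the autocast and scaler calls become no-ops):

```python
import torch
import torch.nn as nn

# Stand-in for the BERT classifier; any model/criterion slots in the same way
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Linear(768, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

use_amp = device == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
accumulation_steps = 4

# Fake batches standing in for tokenized BERT inputs
batches = [(torch.randn(8, 768), torch.randint(0, 2, (8,))) for _ in range(8)]

optimizer.zero_grad()
for i, (features, labels) in enumerate(batches):
    features, labels = features.to(device), labels.to(device)
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = criterion(model(features), labels) / accumulation_steps
    scaler.scale(loss).backward()
    if (i + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)  # unscale before clipping gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```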

BERT implementation in PyTorch is straightforward with the transformers library. Focus on proper fine-tuning techniques, memory optimization, and efficient inference patterns. Start with pre-trained models, fine-tune on your specific task, and optimize only when performance becomes a bottleneck.
