How to Implement Named Entity Recognition in PyTorch

Key Insights

  • NER requires careful handling of the BIO tagging scheme, and padding positions must be masked out of the loss so they do not distort training
  • A BiLSTM reads each sentence in both directions, so every tag decision can draw on the words that follow as well as those that precede; this bidirectional context typically helps resolve entity boundaries compared with a unidirectional model
  • Adding a CRF layer on top of the BiLSTM outputs learns which tag transitions are valid (for example, I-PER should not follow B-LOC), which often improves F1 by roughly 2-4 points

Introduction to Named Entity Recognition

Named Entity Recognition (NER) is a fundamental NLP task that identifies and classifies named entities in text into predefined categories like person names, organizations, locations, dates, and monetary values. If you’ve ever wondered how Gmail automatically detects addresses to show map links or how chatbots extract relevant information from user queries, NER is doing the heavy lifting.

PyTorch is well suited to NER because it gives fine-grained control over model architecture and training dynamics. Higher-level frameworks tend to hide the details that matter most in sequence labeling, such as padding masks and variable-length sequences; PyTorch lets you handle them explicitly, exactly as you need.

The standard approach treats NER as a token-level classification problem where each word receives a label. We'll build a BiLSTM tagger, a long-standing strong baseline on the CoNLL-2003 benchmark.

Dataset Preparation and Preprocessing

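The CoNLL-2003 dataset is the de facto standard for NER evaluation. It annotates four entity types (PER, LOC, ORG, MISC), and most distributions encode them with the BIO scheme (also called IOB2): B-X marks the first token of an entity of type X, I-X marks every following token of that entity, and O marks tokens outside any entity. That gives nine tags: B-PER, I-PER, B-LOC, I-LOC, B-ORG, I-ORG, B-MISC, I-MISC, and O. The scheme handles multi-token entities cleanly: "New York" becomes B-LOC I-LOC, and two adjacent entities of the same type stay separate because each one opens with its own B- tag. (The original CoNLL-2003 files use the older IOB1 variant, where B- appears only when an entity directly follows another of the same type, so check which scheme your copy uses.)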
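To make the scheme concrete, here is a minimal sketch that converts entity spans into BIO tags. The helper name `spans_to_bio` and its `(start, end, entity_type)` span format with an exclusive `end` are assumptions for illustration, not part of any dataset loader. The second call shows why B- matters for two adjacent entities of the same type.

```python
def spans_to_bio(tokens, spans):
    """Convert entity spans to BIO tags (hypothetical helper).

    spans: list of (start, end, entity_type) with `end` exclusive,
    e.g. (3, 5, 'LOC') covers tokens[3] and tokens[4].
    """
    tags = ['O'] * len(tokens)
    for start, end, entity_type in spans:
        tags[start] = f'B-{entity_type}'      # first token of the entity
        for i in range(start + 1, end):
            tags[i] = f'I-{entity_type}'      # continuation tokens
    return tags

tokens = ['John', 'lives', 'in', 'New', 'York']
print(spans_to_bio(tokens, [(0, 1, 'PER'), (3, 5, 'LOC')]))
# ['B-PER', 'O', 'O', 'B-LOC', 'I-LOC']

# Two adjacent PER entities stay distinguishable
print(spans_to_bio(['Alice', 'Bob', 'left'], [(0, 1, 'PER'), (1, 2, 'PER')]))
# ['B-PER', 'B-PER', 'O']
```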

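Here's a dataset implementation that maps words and tags to indices, pads each sentence to a fixed length, and builds the matching attention mask: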

import torch
from torch.utils.data import Dataset
from collections import Counter

class NERDataset(Dataset):
    def __init__(self, sentences, labels, word2idx, label2idx, max_len=128):
        self.sentences = sentences
        self.labels = labels
        self.word2idx = word2idx
        self.label2idx = label2idx
        self.max_len = max_len
        
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        words = self.sentences[idx]
        tags = self.labels[idx]
        
        # Convert to indices
        word_ids = [self.word2idx.get(w.lower(), self.word2idx['<UNK>']) 
                    for w in words[:self.max_len]]
        label_ids = [self.label2idx[t] for t in tags[:self.max_len]]
        
        # Pad sequences
        padding_len = self.max_len - len(word_ids)
        word_ids += [self.word2idx['<PAD>']] * padding_len
        label_ids += [self.label2idx['O']] * padding_len  # 'O' padding; masked out of the loss
        
        return {
            'input_ids': torch.tensor(word_ids, dtype=torch.long),
            'labels': torch.tensor(label_ids, dtype=torch.long),
            'attention_mask': torch.tensor([1] * len(words[:self.max_len]) + 
                                          [0] * padding_len, dtype=torch.long)
        }

def build_vocab(sentences, min_freq=2):
    """Build vocabulary from sentences."""
    word_counter = Counter()
    for sent in sentences:
        word_counter.update([w.lower() for w in sent])
    
    # Keep only words above threshold
    vocab = ['<PAD>', '<UNK>'] + [w for w, c in word_counter.items() 
                                   if c >= min_freq]
    word2idx = {w: i for i, w in enumerate(vocab)}
    return word2idx

# Example usage
sentences = [['John', 'lives', 'in', 'New', 'York'], 
             ['Apple', 'Inc.', 'is', 'a', 'company']]
labels = [['B-PER', 'O', 'O', 'B-LOC', 'I-LOC'],
          ['B-ORG', 'I-ORG', 'O', 'O', 'O']]

label2idx = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-LOC': 3, 
             'I-LOC': 4, 'B-ORG': 5, 'I-ORG': 6, 'B-MISC': 7, 'I-MISC': 8}
word2idx = build_vocab(sentences, min_freq=1)  # toy corpus: every word occurs only once

dataset = NERDataset(sentences, labels, word2idx, label2idx)

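The attention mask is critical: it records which positions hold real tokens and which are padding. The loss and evaluation code use it to skip pad positions; otherwise the model would be rewarded for predicting 'O' on padding, which inflates accuracy and dilutes the gradient from real tokens.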
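NERDataset pads every sentence to max_len=128, so a batch of short sentences is mostly padding. A common alternative is to pad each batch only to its longest sentence and derive the mask from the lengths. This sketch assumes a variant of NERDataset that returns unpadded `input_ids` and `labels` tensors; `pad_collate` is a hypothetical name, and `pad_id=0` / `pad_label_id=0` match the `<PAD>` and 'O' indices used above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch, pad_id=0, pad_label_id=0):
    """Hypothetical collate_fn: pad each batch only to its longest sentence.

    Assumes every example is a dict of *unpadded* 1-D LongTensors
    'input_ids' and 'labels'.
    """
    input_ids = pad_sequence([ex['input_ids'] for ex in batch],
                             batch_first=True, padding_value=pad_id)
    labels = pad_sequence([ex['labels'] for ex in batch],
                          batch_first=True, padding_value=pad_label_id)
    lengths = torch.tensor([len(ex['input_ids']) for ex in batch])
    # attention_mask[i, j] = 1 if position j is a real token of sentence i
    positions = torch.arange(input_ids.size(1)).unsqueeze(0)    # (1, max_len)
    attention_mask = (positions < lengths.unsqueeze(1)).long()  # (batch, max_len)
    return {'input_ids': input_ids, 'labels': labels,
            'attention_mask': attention_mask}

batch = [
    {'input_ids': torch.tensor([5, 8, 2]), 'labels': torch.tensor([1, 0, 3])},
    {'input_ids': torch.tensor([7, 4]), 'labels': torch.tensor([5, 6])},
]
out = pad_collate(batch)
print(out['input_ids'].tolist())       # [[5, 8, 2], [7, 4, 0]]
print(out['attention_mask'].tolist())  # [[1, 1, 1], [1, 1, 0]]
```

It would be passed to a DataLoader as `collate_fn=pad_collate`.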

Building the NER Model Architecture

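A BiLSTM is a strong, well-understood baseline for NER: expressive enough to capture context, simple enough to train efficiently. The bidirectional part matters because entity boundaries and types often depend on both preceding and following words. In "Jordan scored 30 points" the first word is a person, while in "Jordan borders Syria" it is a location, and only the words after it tell the two apart.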
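A quick way to see the difference: in a unidirectional LSTM the output at position 0 depends only on the first token, while in a bidirectional one it also reflects later tokens through the backward pass. This small check uses random inputs and untrained layers (the sizes are arbitrary) and changes only the last token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb_dim, hidden = 8, 4
uni = nn.LSTM(emb_dim, hidden, batch_first=True)
bi = nn.LSTM(emb_dim, hidden, batch_first=True, bidirectional=True)

x = torch.randn(1, 5, emb_dim)           # one sentence of 5 token vectors
x_changed = x.clone()
x_changed[0, 4] = torch.randn(emb_dim)   # alter only the last token

with torch.no_grad():
    uni_a, _ = uni(x)
    uni_b, _ = uni(x_changed)
    bi_a, _ = bi(x)
    bi_b, _ = bi(x_changed)

# Does the output at position 0 change when only position 4 changes?
uni_affected = not torch.allclose(uni_a[0, 0], uni_b[0, 0])
bi_affected = not torch.allclose(bi_a[0, 0], bi_b[0, 0])

print(tuple(bi_a.shape))          # (1, 5, 8): forward and backward states concatenated
print(uni_affected, bi_affected)  # False True
```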

import torch.nn as nn

class BiLSTM_NER(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_labels, 
                 num_layers=2, dropout=0.3):
        super(BiLSTM_NER, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, 
                                      padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers,
                           bidirectional=True, batch_first=True, 
                           dropout=dropout if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        
        # *2 because bidirectional
        self.classifier = nn.Linear(hidden_dim * 2, num_labels)
        
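    def forward(self, input_ids, attention_mask=None):
        # input_ids: (batch_size, seq_len)
        embeddings = self.embedding(input_ids)  # (batch_size, seq_len, embedding_dim)
        embeddings = self.dropout(embeddings)
        
        if attention_mask is not None:
            # Pack the batch so the backward LSTM starts at each sentence's last
            # real token instead of first running over the trailing padding
            lengths = attention_mask.sum(dim=1).cpu()
            packed = nn.utils.rnn.pack_padded_sequence(
                embeddings, lengths, batch_first=True, enforce_sorted=False)
            packed_out, _ = self.lstm(packed)
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=input_ids.size(1))
        else:
            lstm_out, _ = self.lstm(embeddings)
        # lstm_out: (batch_size, seq_len, hidden_dim*2)
        lstm_out = self.dropout(lstm_out)
        
        # Classify each token
        logits = self.classifier(lstm_out)  # (batch_size, seq_len, num_labels)
        
        return logits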

# Initialize model
model = BiLSTM_NER(
    vocab_size=len(word2idx),
    embedding_dim=100,
    hidden_dim=256,
    num_labels=len(label2idx),
    num_layers=2,
    dropout=0.3
)

Dropout is the main regularizer here. NER training sets are small (the CoNLL-2003 English training split has about 14,000 sentences), so without dropout the model tends to overfit.
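Dropout is only active in training mode, which is why train_epoch() calls model.train() and the evaluation and prediction functions call model.eval(). This standalone check (the tensor size and seed are arbitrary) shows both behaviors of nn.Dropout: in training mode it zeroes roughly a fraction p of the inputs and scales the survivors by 1/(1-p), and in eval mode it passes inputs through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.3)
x = torch.ones(10_000)

drop.train()                     # training mode: dropout is active
y_train = drop(x)
zeroed = (y_train == 0).float().mean().item()
survivors = y_train[y_train != 0]
print(f"fraction zeroed: {zeroed:.2f}")  # close to 0.3
# Surviving entries are scaled by 1 / (1 - p)
print(torch.allclose(survivors, torch.full_like(survivors, 1 / 0.7)))  # True

drop.eval()                      # eval mode: dropout is the identity
print(torch.equal(drop(x), x))   # True
```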

Training Loop Implementation

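The training loop must handle padding correctly. A common mistake is computing the loss over padding positions: because pads are labeled 'O', the model is rewarded for an easy prediction on them, which dilutes the gradient signal from real tokens, biases the model toward 'O', and slows convergence.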
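The masking in compute_loss() can be checked in isolation. This sketch uses random logits and labels (the shapes are arbitrary) and compares the mask-multiply approach with the other common idiom: setting padded labels to -100, the default `ignore_index` of nn.CrossEntropyLoss, whose default 'mean' reduction then averages over the non-ignored positions only. Both give the same value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, num_labels = 2, 5, 9
logits = torch.randn(batch_size, seq_len, num_labels)
labels = torch.randint(0, num_labels, (batch_size, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Approach 1: per-token loss multiplied by the mask (what compute_loss does)
per_token = nn.CrossEntropyLoss(reduction='none')(
    logits.reshape(-1, num_labels), labels.reshape(-1))
mask = attention_mask.reshape(-1).float()
masked_loss = (per_token * mask).sum() / mask.sum()

# Approach 2: set padded labels to -100, CrossEntropyLoss's default ignore_index;
# the default 'mean' reduction then averages over the 8 real tokens only
ignore_labels = labels.masked_fill(attention_mask == 0, -100)
ignore_loss = nn.CrossEntropyLoss()(
    logits.reshape(-1, num_labels), ignore_labels.reshape(-1))

print(torch.allclose(masked_loss, ignore_loss))  # True
```

The -100 idiom is convenient for plain softmax taggers, but keep the 'O' padding if you plan to add the CRF layer: pytorch-crf still uses padded labels as indices before masking their scores, so they need to be valid tag ids.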

from torch.utils.data import DataLoader
import torch.optim as optim

def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        logits = model(input_ids, attention_mask)
        
        # Masked loss over real tokens only
        # logits: (batch_size, seq_len, num_labels)
        # labels: (batch_size, seq_len)
        loss = compute_loss(logits, labels, attention_mask)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

def compute_loss(logits, labels, attention_mask):
    """Compute cross-entropy loss ignoring padding tokens."""
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    
    # Flatten predictions and labels
    logits_flat = logits.view(-1, logits.size(-1))  # (batch*seq_len, num_labels)
    labels_flat = labels.view(-1)  # (batch*seq_len)
    
    # Compute loss
    loss = loss_fn(logits_flat, labels_flat)
    
    # Apply mask to ignore padding
    mask_flat = attention_mask.view(-1).float()
    loss = (loss * mask_flat).sum() / mask_flat.sum()
    
    return loss

# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    avg_loss = train_epoch(model, dataloader, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

Gradient clipping rescales the gradients whenever their combined L2 norm exceeds max_norm (1.0 here), which guards against exploding gradients, a common problem when training RNNs on long sequences. The weight decay in AdamW adds further regularization.
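A minimal check of what clip_grad_norm_ does, using a single hand-made parameter whose gradient [3, 4] has L2 norm 5: the function returns the norm measured before clipping and rescales the gradient in place so its norm becomes max_norm.

```python
import torch

param = torch.nn.Parameter(torch.zeros(2))
param.grad = torch.tensor([3.0, 4.0])              # L2 norm = 5

total_norm = torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)
print(round(total_norm.item(), 4))                 # 5.0 (norm before clipping)
print([round(g, 4) for g in param.grad.tolist()])  # [0.6, 0.8] (rescaled to norm 1)
```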

Evaluation and Prediction

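NER evaluation requires entity-level metrics rather than token-level accuracy. Most tokens are tagged O, so a model can reach high token accuracy while still missing entities. With entity-level scoring, a predicted entity counts as correct only if its type and both of its boundaries exactly match a gold entity.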
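The seqeval library used below implements this kind of scoring. To see why token accuracy misleads, here is a minimal, self-contained sketch of exact-match entity F1. It is not seqeval's implementation: `get_entities` and `entity_f1` are hypothetical helpers, and `get_entities` treats an I-X tag that does not continue an X entity as the start of a new one, following the lenient convention of the conlleval script. In the example, the prediction truncates "John Smith" to "John": 9 of 10 tokens are right, but only one of the two entities is.

```python
def get_entities(tags):
    """Return a set of (type, start, end) spans from BIO tags; `end` is exclusive."""
    entities = set()
    ent_type, start = None, None
    for i, tag in enumerate(tags + ['O']):  # sentinel 'O' flushes the last entity
        prefix, _, label = tag.partition('-')
        continues = prefix == 'I' and label == ent_type
        if ent_type is not None and not continues:
            entities.add((ent_type, start, i))  # close the open entity
            ent_type, start = None, None
        if prefix in ('B', 'I') and not continues:
            ent_type, start = label, i          # open a new entity
    return entities

def entity_f1(gold_seqs, pred_seqs):
    """Micro-averaged F1 over exactly matching entity spans."""
    tp = n_pred = n_gold = 0
    for gold_tags, pred_tags in zip(gold_seqs, pred_seqs):
        gold, pred = get_entities(gold_tags), get_entities(pred_tags)
        tp += len(gold & pred)
        n_pred += len(pred)
        n_gold += len(gold)
    if tp == 0:
        return 0.0
    precision, recall = tp / n_pred, tp / n_gold
    return 2 * precision * recall / (precision + recall)

# "John Smith flew to Paris on Monday with his family"
gold = [['B-PER', 'I-PER', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O']]
pred = [['B-PER', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O']]

token_acc = sum(g == p for g, p in zip(gold[0], pred[0])) / len(gold[0])
print(token_acc)               # 0.9
print(entity_f1(gold, pred))   # 0.5
```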

from seqeval.metrics import f1_score, classification_report

def evaluate(model, dataloader, idx2label, device):
    model.eval()
    predictions, true_labels = [], []
    
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            logits = model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=-1).cpu().tolist()
            labels = batch['labels'].tolist()
            # Padding is always at the end, so each sentence's real length is its mask sum
            lengths = batch['attention_mask'].sum(dim=1).tolist()
            
            # Convert to label strings, dropping padded positions
            for pred_seq, label_seq, n in zip(preds, labels, lengths):
                predictions.append([idx2label[p] for p in pred_seq[:n]])
                true_labels.append([idx2label[t] for t in label_seq[:n]])
    
    # Calculate F1 score
    f1 = f1_score(true_labels, predictions)
    report = classification_report(true_labels, predictions)
    
    return f1, report

def predict_entities(model, sentence, word2idx, idx2label, device):
    """Extract entities from raw text."""
    model.eval()
    
    # Tokenize and convert to indices
    words = sentence.split()
    word_ids = [word2idx.get(w.lower(), word2idx['<UNK>']) for w in words]
    input_ids = torch.tensor([word_ids], dtype=torch.long).to(device)
    
    with torch.no_grad():
        logits = model(input_ids)
        predictions = torch.argmax(logits, dim=-1)[0].cpu().numpy()
    
    # Extract entities
    entities = []
    current_entity = []
    current_label = None
    
    for word, pred_idx in zip(words, predictions):
        label = idx2label[pred_idx]
        
        if label.startswith('B-'):
            if current_entity:
                entities.append((' '.join(current_entity), current_label))
            current_entity = [word]
            current_label = label[2:]
        elif label.startswith('I-') and current_label == label[2:]:
            current_entity.append(word)
        else:
            if current_entity:
                entities.append((' '.join(current_entity), current_label))
            current_entity = []
            current_label = None
    
    if current_entity:
        entities.append((' '.join(current_entity), current_label))
    
    return entities

# Example usage
idx2label = {v: k for k, v in label2idx.items()}
text = "Apple Inc. was founded by Steve Jobs in California"
entities = predict_entities(model, text, word2idx, idx2label, device)
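print(entities)
# A model trained on the full CoNLL-2003 training set could produce output like
# [('Apple Inc.', 'ORG'), ('Steve Jobs', 'PER'), ('California', 'LOC')];
# the toy model above (two training sentences) will not, since most words map to <UNK>.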

Practical Tips and Improvements

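The BiLSTM model is solid, but it chooses each tag independently, so nothing stops it from emitting an invalid sequence such as I-PER directly after B-LOC. A CRF layer on top scores whole tag sequences instead: it learns a transition score for every pair of adjacent tags, so transitions that never appear in the training data end up with low scores, and at inference it returns the best-scoring sequence via Viterbi decoding. This often improves F1 by roughly 2-4 points: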

from torchcrf import CRF  # provided by the pytorch-crf package

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_labels):
        super(BiLSTM_CRF, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, 
                           batch_first=True)
        self.classifier = nn.Linear(hidden_dim * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
        
    def forward(self, input_ids, labels=None, attention_mask=None):
        # Call with keyword arguments: the positional order differs from
        # BiLSTM_NER, so train_epoch() above cannot be reused unchanged
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        mask = attention_mask.bool()
        
        embeddings = self.embedding(input_ids)
        lstm_out, _ = self.lstm(embeddings)
        emissions = self.classifier(lstm_out)  # per-token tag scores
        
        if labels is not None:
            # Training: negative log-likelihood of the gold tag sequences,
            # averaged over the batch
            return -self.crf(emissions, labels, mask=mask, reduction='mean')
        # Inference: Viterbi decoding; returns one list of tag indices
        # per sentence, with padded positions excluded
        return self.crf.decode(emissions, mask=mask)
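For intuition about what crf.decode() computes, here is a minimal Viterbi sketch over a single sentence. It is a simplification: the transition scores are hand-set hard constraints (a large negative score for any I-X not preceded by B-X or I-X) rather than learned, the start and end transition scores that pytorch-crf also learns are omitted, and the emission scores are made-up numbers chosen so that picking the best tag per token yields the invalid sequence B-LOC I-PER.

```python
import torch

def viterbi_decode(emissions, transitions):
    """Return the highest-scoring tag sequence for one sentence.

    emissions:   (seq_len, num_tags) per-token tag scores
    transitions: (num_tags, num_tags), transitions[i, j] = score of tag i -> tag j
    """
    score = emissions[0].clone()   # best path score ending in each tag at step 0
    backpointers = []
    for t in range(1, emissions.size(0)):
        # candidate[i, j]: best path ending in tag i at t-1, then tag j at t
        candidate = score.unsqueeze(1) + transitions + emissions[t].unsqueeze(0)
        score, best_prev = candidate.max(dim=0)
        backpointers.append(best_prev)
    # Follow the backpointers from the best final tag
    path = [int(score.argmax())]
    for best_prev in reversed(backpointers):
        path.append(int(best_prev[path[-1]]))
    return path[::-1]

tags = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']
# Hand-set constraint: I-X may only follow B-X or I-X
transitions = torch.zeros(len(tags), len(tags))
for i, prev in enumerate(tags):
    for j, nxt in enumerate(tags):
        if nxt.startswith('I-') and prev[2:] != nxt[2:]:
            transitions[i, j] = -1e4

# Made-up emission scores for a two-token sentence
emissions = torch.tensor([[0.0, 2.0, 0.0, 2.5, 0.0],
                          [0.0, 0.0, 3.0, 0.0, 1.0]])

greedy = [tags[i] for i in emissions.argmax(dim=1).tolist()]
best = [tags[i] for i in viterbi_decode(emissions, transitions)]
print(greedy)  # ['B-LOC', 'I-PER'] (invalid)
print(best)    # ['B-PER', 'I-PER']
```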

If accuracy matters most, consider fine-tuning a pre-trained transformer such as BERT or RoBERTa with Hugging Face Transformers. These models typically gain several F1 points over the BiLSTM here (roughly 3-5), at a substantially higher compute cost. The BiLSTM approach shown here remains a reasonable choice when latency, memory, or training budget is the main constraint.
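One practical difference with transformers is that their tokenizers split words into subword pieces, so word-level tags must be re-aligned. A common convention is to label only the first subword of each word and mark the rest, along with special tokens, with -100 so the loss ignores them. This sketch assumes the `word_ids` list format produced by a Hugging Face fast tokenizer's `word_ids()` method (None for special tokens, otherwise the source word's index); the function name is hypothetical, and the subword split in the example is hand-written for illustration, not actual tokenizer output.

```python
def align_labels_with_subwords(word_ids, word_label_ids, ignore_index=-100):
    """Map word-level label ids onto subword tokens (hypothetical helper).

    word_ids: one entry per subword token; None for special tokens,
              otherwise the index of the word the token came from.
    Only the first subword of each word keeps the word's label.
    """
    aligned = []
    previous = None
    for wid in word_ids:
        if wid is None:
            aligned.append(ignore_index)         # special tokens such as [CLS], [SEP]
        elif wid != previous:
            aligned.append(word_label_ids[wid])  # first subword of a word
        else:
            aligned.append(ignore_index)         # continuation subword
        previous = wid
    return aligned

# Words: ['Steve', 'Jobs', 'founded', 'Apple'] with labels B-PER, I-PER, O, B-ORG
word_labels = [1, 2, 0, 5]
# Hand-written split: [CLS] Steve Jo ##bs founded Apple [SEP]
word_ids = [None, 0, 1, 1, 2, 3, None]
print(align_labels_with_subwords(word_ids, word_labels))
# [-100, 1, 2, -100, 0, 5, -100]
```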

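Besides the standard held-out test set, evaluate on text from the domain you will deploy in whenever it differs from your training data. NER models often overfit to specific writing styles and entity distributions, so scores on an in-domain test set can overstate real-world performance.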
