How to Implement Named Entity Recognition in PyTorch
Key Insights
- Named Entity Recognition requires careful handling of the BIO tagging scheme and padding tokens during training to avoid corrupting loss calculations
- A BiLSTM architecture captures bidirectional context essential for disambiguating entity boundaries, significantly outperforming unidirectional models
- Adding a CRF layer on top of BiLSTM outputs enforces valid tag transitions (like preventing I-PER following B-LOC), improving F1 scores by 2-4% in practice
Introduction to Named Entity Recognition
Named Entity Recognition (NER) is a fundamental NLP task that identifies and classifies named entities in text into predefined categories like person names, organizations, locations, dates, and monetary values. If you’ve ever wondered how Gmail automatically detects addresses to show map links or how chatbots extract relevant information from user queries, NER is doing the heavy lifting.
PyTorch excels at NER implementation because it provides fine-grained control over model architecture and training dynamics. Unlike higher-level frameworks that abstract away crucial details, PyTorch lets you handle the nuances of sequence labeling—like padding mask management and variable-length sequences—exactly how you need to.
The standard approach treats NER as a token-level classification problem where each word receives a label. We’ll build a production-ready BiLSTM model that achieves competitive performance on the CoNLL-2003 benchmark.
Dataset Preparation and Preprocessing
The CoNLL-2003 dataset is the de facto standard for NER evaluation. It uses the BIO tagging scheme: B-PER (beginning of person), I-PER (inside person), B-LOC, I-LOC, B-ORG, I-ORG, B-MISC, I-MISC, and O (outside any entity). This scheme handles multi-token entities cleanly.
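As a concrete illustration (example sentence ours, not from CoNLL-2003), here is how a sentence is tagged under the BIO scheme:

```python
# BIO tagging for "George Washington visited New York City"
tokens = ["George", "Washington", "visited", "New", "York", "City"]
tags   = ["B-PER",  "I-PER",      "O",       "B-LOC", "I-LOC", "I-LOC"]

# Each entity opens with a B- tag; subsequent tokens of the same entity get I-
assert len(tokens) == len(tags)
print(sum(t.startswith("B-") for t in tags))  # 2 entities
```

Counting B- tags gives the number of entities, which is why a stray B- in the middle of an entity splits it in two.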
Here’s a robust dataset implementation:
```python
import torch
from torch.utils.data import Dataset
from collections import Counter

class NERDataset(Dataset):
    def __init__(self, sentences, labels, word2idx, label2idx, max_len=128):
        self.sentences = sentences
        self.labels = labels
        self.word2idx = word2idx
        self.label2idx = label2idx
        self.max_len = max_len

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        words = self.sentences[idx]
        tags = self.labels[idx]
        # Convert to indices
        word_ids = [self.word2idx.get(w.lower(), self.word2idx['<UNK>'])
                    for w in words[:self.max_len]]
        label_ids = [self.label2idx[t] for t in tags[:self.max_len]]
        # Pad sequences
        padding_len = self.max_len - len(word_ids)
        word_ids += [self.word2idx['<PAD>']] * padding_len
        # Pad labels as 'O'; the attention mask keeps them out of the loss
        label_ids += [self.label2idx['O']] * padding_len
        return {
            'input_ids': torch.tensor(word_ids, dtype=torch.long),
            'labels': torch.tensor(label_ids, dtype=torch.long),
            'attention_mask': torch.tensor([1] * len(words[:self.max_len]) +
                                           [0] * padding_len, dtype=torch.long)
        }

def build_vocab(sentences, min_freq=2):
    """Build vocabulary from sentences."""
    word_counter = Counter()
    for sent in sentences:
        word_counter.update([w.lower() for w in sent])
    # Keep only words above the frequency threshold
    vocab = ['<PAD>', '<UNK>'] + [w for w, c in word_counter.items()
                                  if c >= min_freq]
    word2idx = {w: i for i, w in enumerate(vocab)}
    return word2idx

# Example usage
sentences = [['John', 'lives', 'in', 'New', 'York'],
             ['Apple', 'Inc.', 'is', 'a', 'company']]
labels = [['B-PER', 'O', 'O', 'B-LOC', 'I-LOC'],
          ['B-ORG', 'I-ORG', 'O', 'O', 'O']]
label2idx = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-LOC': 3,
             'I-LOC': 4, 'B-ORG': 5, 'I-ORG': 6, 'B-MISC': 7, 'I-MISC': 8}
# min_freq=1 here: every word in this toy corpus appears once, and the
# default min_freq=2 would map them all to <UNK>
word2idx = build_vocab(sentences, min_freq=1)
dataset = NERDataset(sentences, labels, word2idx, label2idx)
```
The attention mask is critical—it tells the model which tokens are real versus padding, preventing the model from learning spurious patterns from pad tokens.
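The mask construction above reduces to a simple pattern. In pure Python, for a 5-word sentence with `max_len=8`:

```python
max_len = 8
words = ["John", "lives", "in", "New", "York"]

# Real tokens get 1, padding positions get 0
n_real = min(len(words), max_len)
attention_mask = [1] * n_real + [0] * (max_len - n_real)

print(attention_mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

The same mask is reused in three places: loss masking during training, filtering predictions during evaluation, and (if you add a CRF) constraining decoding.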
Building the NER Model Architecture
A BiLSTM architecture is the sweet spot for NER: powerful enough to capture context, simple enough to train efficiently. The bidirectional nature is crucial because entity boundaries often depend on both preceding and following words.
```python
import torch.nn as nn

class BiLSTM_NER(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_labels,
                 num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim,
                                      padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers,
                            bidirectional=True, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        # *2 because bidirectional: forward and backward states are concatenated
        self.classifier = nn.Linear(hidden_dim * 2, num_labels)

    def forward(self, input_ids, attention_mask=None):
        # input_ids: (batch_size, seq_len)
        embeddings = self.embedding(input_ids)  # (batch_size, seq_len, embedding_dim)
        embeddings = self.dropout(embeddings)
        lstm_out, _ = self.lstm(embeddings)     # (batch_size, seq_len, hidden_dim*2)
        lstm_out = self.dropout(lstm_out)
        # Classify each token
        logits = self.classifier(lstm_out)      # (batch_size, seq_len, num_labels)
        return logits

# Initialize model
model = BiLSTM_NER(
    vocab_size=len(word2idx),
    embedding_dim=100,
    hidden_dim=256,
    num_labels=len(label2idx),
    num_layers=2,
    dropout=0.3
)
```
The dropout layers are essential for regularization. Without them, the model will overfit on typical NER datasets, which have relatively few training examples.
Training Loop Implementation
The training loop must handle padding correctly. A common mistake is computing loss over padding tokens, which dilutes the gradient signal and slows convergence.
```python
from torch.utils.data import DataLoader
import torch.optim as optim

def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        optimizer.zero_grad()
        # Forward pass
        logits = model(input_ids, attention_mask)
        # logits: (batch_size, seq_len, num_labels); labels: (batch_size, seq_len)
        loss = compute_loss(logits, labels, attention_mask)
        # Backward pass with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def compute_loss(logits, labels, attention_mask):
    """Compute cross-entropy loss, ignoring padding tokens."""
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    # Flatten predictions and labels
    logits_flat = logits.view(-1, logits.size(-1))  # (batch*seq_len, num_labels)
    labels_flat = labels.view(-1)                   # (batch*seq_len)
    loss = loss_fn(logits_flat, labels_flat)
    # Zero out padding positions, then average over real tokens only
    mask_flat = attention_mask.view(-1).float()
    loss = (loss * mask_flat).sum() / mask_flat.sum()
    return loss

# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    avg_loss = train_epoch(model, dataloader, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
```
Gradient clipping prevents exploding gradients, a common issue with RNNs on variable-length sequences. The weight decay in AdamW provides additional regularization.
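To see why the masking in `compute_loss` matters, compare the masked average against a naive average over all positions. With toy numbers for one 4-position sequence whose last position is padding:

```python
# Per-token losses; the padding token carries a large, meaningless loss value
token_losses = [0.9, 0.4, 0.7, 2.5]
mask = [1, 1, 1, 0]

# Masked average: only real tokens contribute
masked = sum(l * m for l, m in zip(token_losses, mask)) / sum(mask)
# Naive average: padding dilutes (or here, inflates) the signal
naive = sum(token_losses) / len(token_losses)

print(round(masked, 4))  # 0.6667
print(round(naive, 4))   # 1.125
```

The numbers are invented for illustration, but the arithmetic is exactly what the `(loss * mask_flat).sum() / mask_flat.sum()` line computes in tensor form.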
Evaluation and Prediction
NER evaluation requires entity-level metrics, not just token-level accuracy. A prediction is only correct if all tokens of an entity match exactly.
```python
from seqeval.metrics import f1_score, classification_report

def evaluate(model, dataloader, idx2label, device):
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']
            logits = model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=-1).cpu().numpy()
            # Convert to label strings, filtering padding
            for i in range(len(preds)):
                pred_labels = []
                true_label_seq = []
                for j in range(len(preds[i])):
                    if attention_mask[i][j] == 1:
                        pred_labels.append(idx2label[preds[i][j]])
                        true_label_seq.append(idx2label[labels[i][j].item()])
                predictions.append(pred_labels)
                true_labels.append(true_label_seq)
    # Entity-level F1 via seqeval
    f1 = f1_score(true_labels, predictions)
    report = classification_report(true_labels, predictions)
    return f1, report

def predict_entities(model, sentence, word2idx, idx2label, device):
    """Extract entities from raw text."""
    model.eval()
    # Tokenize and convert to indices
    words = sentence.split()
    word_ids = [word2idx.get(w.lower(), word2idx['<UNK>']) for w in words]
    input_ids = torch.tensor([word_ids], dtype=torch.long).to(device)
    with torch.no_grad():
        logits = model(input_ids)
    predictions = torch.argmax(logits, dim=-1)[0].cpu().numpy()
    # Walk the BIO tags and group consecutive tokens into entities
    entities = []
    current_entity = []
    current_label = None
    for word, pred_idx in zip(words, predictions):
        label = idx2label[pred_idx]
        if label.startswith('B-'):
            if current_entity:
                entities.append((' '.join(current_entity), current_label))
            current_entity = [word]
            current_label = label[2:]
        elif label.startswith('I-') and current_label == label[2:]:
            current_entity.append(word)
        else:
            if current_entity:
                entities.append((' '.join(current_entity), current_label))
            current_entity = []
            current_label = None
    if current_entity:
        entities.append((' '.join(current_entity), current_label))
    return entities

# Example usage
idx2label = {v: k for k, v in label2idx.items()}
text = "Apple Inc. was founded by Steve Jobs in California"
entities = predict_entities(model, text, word2idx, idx2label, device)
# After training on real data, expect something like:
# [('Apple Inc.', 'ORG'), ('Steve Jobs', 'PER'), ('California', 'LOC')]
print(entities)
```
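To make the entity-level criterion concrete, here is a simplified pure-Python sketch of span matching (seqeval's internals are more general; this is just for intuition). A prediction that misses one token of a two-token entity scores well on token accuracy but gets zero credit for that entity:

```python
def bio_to_spans(tags):
    """Convert a BIO tag sequence into (start, end, type) entity spans."""
    spans, start, etype = [], None, None
    for i, tag in enumerate(tags):
        if tag.startswith("B-"):
            if start is not None:
                spans.append((start, i, etype))
            start, etype = i, tag[2:]
        elif tag.startswith("I-") and etype == tag[2:]:
            continue  # entity continues
        else:
            if start is not None:
                spans.append((start, i, etype))
            start, etype = None, None
    if start is not None:
        spans.append((start, len(tags), etype))
    return spans

gold = ["B-PER", "I-PER", "O", "B-LOC"]
pred = ["B-PER", "O",     "O", "B-LOC"]  # second PER token missed

token_acc = sum(g == p for g, p in zip(gold, pred)) / len(gold)
correct = set(bio_to_spans(gold)) & set(bio_to_spans(pred))
print(token_acc)     # 0.75 — looks decent
print(len(correct))  # 1 — only 1 of 2 gold entities matched exactly
```

This gap between token accuracy and entity F1 is why `seqeval` (entity-level) is the standard reporting tool, not plain classification accuracy.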
Practical Tips and Improvements
The BiLSTM model is solid, but adding a CRF layer enforces valid tag transitions. For example, I-PER should never follow B-LOC. This constraint improves F1 by 2-4%:
```python
from torchcrf import CRF  # pip install pytorch-crf

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_labels):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True,
                            batch_first=True)
        self.classifier = nn.Linear(hidden_dim * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, labels=None, attention_mask=None):
        embeddings = self.embedding(input_ids)
        lstm_out, _ = self.lstm(embeddings)
        emissions = self.classifier(lstm_out)
        if labels is not None:
            # Training: return negative log-likelihood
            loss = -self.crf(emissions, labels, mask=attention_mask.bool())
            return loss
        else:
            # Inference: return the best tag path (Viterbi decoding)
            predictions = self.crf.decode(emissions, mask=attention_mask.bool())
            return predictions
```
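The constraint the CRF's transition matrix learns can be stated directly. Here is an illustrative helper (ours, not part of torchcrf) encoding the BIO validity rule:

```python
def valid_bio_transition(prev, curr):
    """An I-X tag may only follow B-X or I-X of the same entity type X."""
    if curr.startswith("I-"):
        return prev in (f"B-{curr[2:]}", f"I-{curr[2:]}")
    return True  # O and B- tags can follow anything

print(valid_bio_transition("B-LOC", "I-PER"))  # False — the CRF learns to forbid this
print(valid_bio_transition("B-PER", "I-PER"))  # True
print(valid_bio_transition("O", "B-LOC"))      # True
```

A plain softmax classifier scores each token independently and can emit the invalid `B-LOC → I-PER` sequence; the CRF scores whole tag sequences, so such transitions get driven toward very negative scores during training.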
For production systems, consider using pre-trained transformer models like BERT or RoBERTa through HuggingFace. They achieve 3-5% higher F1 scores but require significantly more compute. The BiLSTM approach demonstrated here offers the best balance of performance and efficiency for most applications.
Always validate on a held-out test set from a different domain than your training data. NER models often overfit to specific writing styles and entity distributions.