How to Implement Word Embeddings in PyTorch
Key Insights
- PyTorch's nn.Embedding layer functions as a trainable lookup table that maps integer token IDs to dense vector representations, making it the foundation for most NLP models
- Training custom embeddings from scratch with Skip-gram requires building a vocabulary, generating context-target pairs, and optimizing embeddings through negative sampling
- Pre-trained embeddings like GloVe can be loaded into PyTorch models and either frozen for feature extraction or fine-tuned on domain-specific data for better performance
Introduction to Word Embeddings
Word embeddings transform discrete words into continuous vector representations that capture semantic relationships. Unlike one-hot encoding, which creates sparse vectors with no notion of similarity, embeddings place semantically similar words closer together in vector space. This means “king” and “queen” have similar representations, while “king” and “banana” are distant.
These dense representations power modern NLP tasks: sentiment analysis, machine translation, named entity recognition, and question answering. Popular pre-trained embeddings include Word2Vec (trained on Google News), GloVe (with variants trained on Wikipedia/Gigaword and Common Crawl), and FastText (which incorporates subword information). PyTorch makes working with embeddings straightforward, whether you're training from scratch or using pre-trained weights.
PyTorch’s nn.Embedding Layer
The nn.Embedding layer is a simple lookup table that stores embeddings for a fixed dictionary of words. It takes two primary parameters: num_embeddings (vocabulary size) and embedding_dim (vector dimensionality). When you pass token indices to this layer, it returns the corresponding embedding vectors.
```python
import torch
import torch.nn as nn

# Create embedding layer: 10,000 words, 300-dimensional vectors
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=300)

# Look up embeddings for specific token indices
token_ids = torch.LongTensor([5, 142, 8, 3341])
embedded = embedding(token_ids)
print(embedded.shape)  # torch.Size([4, 300])

# Batch processing: [batch_size, sequence_length]
batch_ids = torch.LongTensor([[5, 142, 8], [3341, 22, 7]])
batch_embedded = embedding(batch_ids)
print(batch_embedded.shape)  # torch.Size([2, 3, 300])
```
The embedding weights are randomly initialized and updated during backpropagation. Each row in the weight matrix represents one word’s embedding vector.
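A small sketch makes the "lookup table" behavior concrete: the weight matrix has one row per vocabulary entry, and after a backward pass only the rows that were actually looked up carry nonzero gradients (the dimensions here are toy values, not from the article).

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=5, embedding_dim=3)
print(embedding.weight.shape)  # torch.Size([5, 3]) — one row per word

# Only the rows that were looked up receive gradients
out = embedding(torch.LongTensor([1, 3])).sum()
out.backward()
grad_rows = (embedding.weight.grad.abs().sum(dim=1) != 0)
print(grad_rows.tolist())  # [False, True, False, True, False]
```

This is why embeddings scale well: each training step touches only the rows for the tokens in the batch.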
Building Embeddings from Scratch
To train custom embeddings, you need a vocabulary mapping words to unique indices. Here’s a practical implementation:
```python
from collections import Counter
import re

class Vocabulary:
    def __init__(self, min_freq=2):
        self.min_freq = min_freq
        self.word2idx = {"<PAD>": 0, "<UNK>": 1}
        self.idx2word = {0: "<PAD>", 1: "<UNK>"}
        self.word_counts = Counter()

    def build_vocab(self, texts):
        # Count word frequencies
        for text in texts:
            tokens = self.tokenize(text)
            self.word_counts.update(tokens)
        # Add words meeting the minimum frequency
        idx = 2
        for word, count in self.word_counts.items():
            if count >= self.min_freq:
                self.word2idx[word] = idx
                self.idx2word[idx] = word
                idx += 1

    def tokenize(self, text):
        # Simple tokenization (use spaCy/NLTK for production)
        text = text.lower()
        return re.findall(r'\b\w+\b', text)

    def encode(self, text):
        tokens = self.tokenize(text)
        # Unknown words map to the <UNK> index (1)
        return [self.word2idx.get(token, 1) for token in tokens]

    def decode(self, indices):
        return [self.idx2word.get(idx, "<UNK>") for idx in indices]

    def __len__(self):
        return len(self.word2idx)

# Usage example
texts = [
    "The quick brown fox jumps over the lazy dog",
    "The dog sleeps under the tree",
    "A quick fox runs through the forest"
]

vocab = Vocabulary(min_freq=1)
vocab.build_vocab(texts)
print(f"Vocabulary size: {len(vocab)}")

encoded = vocab.encode("The fox runs")
print(f"Encoded: {encoded}")
print(f"Decoded: {vocab.decode(encoded)}")
```
This vocabulary class handles tokenization, builds word-to-index mappings, and provides encoding/decoding utilities. The <PAD> token handles variable-length sequences, while <UNK> represents out-of-vocabulary words.
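The <PAD> index is what lets you batch variable-length sequences: shorter sequences are extended with index 0 until all sequences in a batch share one length. A minimal sketch with made-up indices:

```python
# Pad encoded sequences to a common length with the <PAD> index (0)
encoded = [[2, 5, 3], [4, 2]]
max_len = max(len(seq) for seq in encoded)
padded = [seq + [0] * (max_len - len(seq)) for seq in encoded]
print(padded)  # [[2, 5, 3], [4, 2, 0]]
```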
Training Custom Embeddings with Skip-gram
Skip-gram predicts context words given a target word. For the sentence “the quick brown fox,” with target word “brown” and window size 2, we generate pairs: (brown, the), (brown, quick), (brown, fox).
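Before the full Dataset implementation, the pair generation above can be verified on the example sentence itself:

```python
tokens = ["the", "quick", "brown", "fox"]
window = 2
pairs = []
for i, target in enumerate(tokens):
    # Every word within `window` positions of the target becomes a context word
    for j in range(max(0, i - window), min(len(tokens), i + window + 1)):
        if i != j:
            pairs.append((target, tokens[j]))

print([p for p in pairs if p[0] == "brown"])
# [('brown', 'the'), ('brown', 'quick'), ('brown', 'fox')]
```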
```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class SkipGramDataset(Dataset):
    def __init__(self, texts, vocab, window_size=2):
        self.pairs = []
        self.vocab = vocab
        for text in texts:
            tokens = vocab.encode(text)
            for i, target in enumerate(tokens):
                # Context indices within the window around position i
                start = max(0, i - window_size)
                end = min(len(tokens), i + window_size + 1)
                for j in range(start, end):
                    if i != j:
                        self.pairs.append((target, tokens[j]))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        target, context = self.pairs[idx]
        # Return scalar tensors so a collated batch has shape [batch_size]
        return torch.tensor(target), torch.tensor(context)

class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.target_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, target, context):
        target_embed = self.target_embeddings(target)     # [batch, dim]
        context_embed = self.context_embeddings(context)  # [batch, dim]
        # Dot-product similarity per pair: [batch]
        return torch.sum(target_embed * context_embed, dim=1)

# Training setup
dataset = SkipGramDataset(texts, vocab, window_size=2)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
model = SkipGramModel(vocab_size=len(vocab), embedding_dim=100)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

# Training loop
for epoch in range(10):
    total_loss = 0
    for target, context in dataloader:
        optimizer.zero_grad()
        # Positive samples: real (target, context) pairs
        pos_score = model(target, context)
        pos_labels = torch.ones_like(pos_score)
        # Negative sampling (simplified: uniform random contexts)
        neg_context = torch.randint(0, len(vocab), context.shape)
        neg_score = model(target, neg_context)
        neg_labels = torch.zeros_like(neg_score)
        # Combined binary cross-entropy loss
        scores = torch.cat([pos_score, neg_score])
        labels = torch.cat([pos_labels, neg_labels])
        loss = criterion(scores, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
```
This implementation uses two embedding matrices: one for target words and one for context words. The model learns by maximizing similarity between actual context pairs while minimizing similarity with random negative samples.
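After training, the target embedding matrix is typically kept as the final word vectors, and word similarity is measured with cosine similarity. A sketch of the lookup (using a fresh, untrained `nn.Embedding` as a stand-in for `model.target_embeddings`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for a trained model.target_embeddings
target_embeddings = nn.Embedding(100, 50)

def similarity(idx_a, idx_b, emb):
    # Cosine similarity between the embedding rows of two word indices
    va = emb(torch.tensor(idx_a))
    vb = emb(torch.tensor(idx_b))
    return F.cosine_similarity(va, vb, dim=0).item()

print(round(similarity(3, 3, target_embeddings), 4))  # 1.0 — a word is identical to itself
```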
Using Pre-trained Embeddings
Loading pre-trained embeddings gives you high-quality representations without extensive training. Here’s how to load GloVe embeddings:
```python
import numpy as np
import torch
import torch.nn as nn

def load_glove_embeddings(file_path, vocab, embedding_dim=300):
    # Initialize with small random vectors for words not found in GloVe
    embeddings = np.random.randn(len(vocab), embedding_dim) * 0.01

    # Load GloVe vectors for words that appear in the vocabulary
    found = 0
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            word = parts[0]
            if word in vocab.word2idx:
                vector = np.array(parts[1:], dtype=np.float32)
                embeddings[vocab.word2idx[word]] = vector
                found += 1
    print(f"Loaded {found}/{len(vocab)} word vectors")
    return torch.FloatTensor(embeddings)

# Create an embedding layer with pre-trained weights
pretrained_embeddings = load_glove_embeddings('glove.6B.300d.txt', vocab)
embedding_layer = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=False)
# freeze=True: embeddings won't be updated during training
# freeze=False: fine-tune embeddings on your specific task
```
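The freeze flag simply controls `requires_grad` on the weight tensor, so you can also start frozen and unfreeze later, a common fine-tuning schedule. A sketch with random weights standing in for loaded GloVe vectors:

```python
import torch
import torch.nn as nn

pretrained = torch.randn(1000, 300)  # stands in for loaded GloVe weights
layer = nn.Embedding.from_pretrained(pretrained, freeze=True)
frozen = layer.weight.requires_grad
print(frozen)  # False — gradients are disabled

# Unfreeze later, e.g. after a warm-up phase training only the task head
layer.weight.requires_grad = True
print(layer.weight.requires_grad)  # True
```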
Practical Application: Sentiment Classifier
Here’s a complete sentiment analysis model using embeddings:
```python
import torch
import torch.nn as nn

class SentimentClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # x: [batch_size, seq_len]
        embedded = self.embedding(x)  # [batch_size, seq_len, embedding_dim]
        embedded = self.dropout(embedded)
        lstm_out, (hidden, cell) = self.lstm(embedded)
        # Concatenate the final forward and backward hidden states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        hidden = self.dropout(hidden)
        return self.fc(hidden)

# Initialize the model
model = SentimentClassifier(
    vocab_size=len(vocab),
    embedding_dim=300,
    hidden_dim=128,
    num_classes=2
)

# Optionally load pre-trained embeddings (re-specify padding_idx, since
# replacing the layer would otherwise drop it)
model.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=False, padding_idx=0)
```
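The hidden-state concatenation is the subtlest line in the classifier, and it can be checked in isolation: a single-layer bidirectional LSTM returns `hidden` with shape `[2, batch, hidden_dim]`, where `hidden[-2]` and `hidden[-1]` are the final forward and backward states (toy dimensions below, not the article's):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True, bidirectional=True)
x = torch.randn(4, 10, 8)  # [batch, seq_len, input_size]
out, (hidden, cell) = lstm(x)
print(hidden.shape)  # torch.Size([2, 4, 16]) — [num_directions, batch, hidden_dim]

# Final forward + backward states, concatenated per example
combined = torch.cat([hidden[-2], hidden[-1]], dim=1)
print(combined.shape)  # torch.Size([4, 32]) — matches nn.Linear(hidden_dim * 2, ...)
```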
Best Practices and Optimization
Embedding Dimensions: Use 50-300 dimensions for most tasks. Larger vocabularies benefit from higher dimensions, but diminishing returns occur beyond 300.
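One concrete cost of larger dimensions is the embedding matrix itself: parameters scale as vocabulary size times dimension. A quick back-of-the-envelope helper (a hypothetical function, assuming float32 parameters):

```python
def embedding_memory_mb(vocab_size, embedding_dim, bytes_per_param=4):
    # float32 parameters: vocab_size * embedding_dim entries, 4 bytes each
    return vocab_size * embedding_dim * bytes_per_param / 1e6

print(embedding_memory_mb(10_000, 300))   # 12.0 MB
print(embedding_memory_mb(100_000, 300))  # 120.0 MB
```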
Padding Strategy: Pad sequences to the same length for efficient batching. Use padding_idx=0 in nn.Embedding to ignore padding tokens:
```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

def collate_batch(batch):
    # batch: list of (text_indices, label) tuples
    texts, labels = zip(*batch)
    # Convert to tensors and record the original lengths
    texts = [torch.LongTensor(text) for text in texts]
    lengths = torch.LongTensor([len(text) for text in texts])
    # Pad sequences to the longest in the batch
    padded_texts = pad_sequence(texts, batch_first=True, padding_value=0)
    labels = torch.LongTensor(labels)
    return padded_texts, labels, lengths

# In your model's forward pass:
def forward(self, x, lengths):
    embedded = self.embedding(x)
    # Pack padded sequences so the LSTM skips padding positions
    packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
    packed_out, (hidden, cell) = self.lstm(packed)
    # Unpack if the per-timestep outputs are needed
    output, _ = pad_packed_sequence(packed_out, batch_first=True)
    return output
```
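The padding step itself is easy to sanity-check in isolation; pad_sequence right-pads every sequence to the length of the longest one in the batch:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.LongTensor([5, 2, 9]), torch.LongTensor([7, 3])]
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
print(padded.tolist())  # [[5, 2, 9], [7, 3, 0]]
print(padded.shape)     # torch.Size([2, 3])
```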
Memory Optimization: For large vocabularies, pass sparse=True to nn.Embedding so that backpropagation produces sparse gradients, touching only the rows looked up in each batch. Note that sparse gradients are supported by only a few optimizers (optim.SGD, optim.SparseAdam, and optim.Adagrad).
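A small sketch of the sparse-gradient setup, pairing sparse=True with SparseAdam (toy sizes, not a tuned configuration):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(100_000, 300, sparse=True)
# Sparse gradients require a compatible optimizer: SGD, SparseAdam, or Adagrad
opt = torch.optim.SparseAdam(emb.parameters())

loss = emb(torch.LongTensor([1, 2, 3])).sum()
loss.backward()
print(emb.weight.grad.is_sparse)  # True — only three rows carry gradient data
opt.step()
```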
Regularization: Apply dropout after the embedding layer to prevent overfitting. Weight decay on embedding parameters can also help, though it’s less common.
Word embeddings are the foundation of modern NLP. Whether training from scratch or using pre-trained weights, PyTorch’s nn.Embedding provides a flexible, efficient interface. Start with pre-trained embeddings for most tasks, and only train custom embeddings when you have domain-specific requirements or sufficient training data.