How to Implement Word Embeddings in PyTorch

Key Insights

  • PyTorch’s nn.Embedding layer functions as a trainable lookup table that maps integer token IDs to dense vector representations, making it the foundation for most NLP models
  • Training custom embeddings from scratch using Skip-gram requires building a vocabulary, generating context-target pairs, and optimizing embeddings through negative sampling
  • Pre-trained embeddings like GloVe can be loaded into PyTorch models and either frozen for feature extraction or fine-tuned on domain-specific data for better performance

Introduction to Word Embeddings

Word embeddings transform discrete words into continuous vector representations that capture semantic relationships. Unlike one-hot encoding, which creates sparse vectors with no notion of similarity, embeddings place semantically similar words closer together in vector space. This means “king” and “queen” have similar representations, while “king” and “banana” are distant.

These dense representations power modern NLP tasks: sentiment analysis, machine translation, named entity recognition, and question answering. Popular pre-trained embeddings include Word2Vec (trained on Google News), GloVe (trained on web crawl data), and FastText (which handles subword information). PyTorch makes working with embeddings straightforward, whether you’re training from scratch or using pre-trained weights.

PyTorch’s nn.Embedding Layer

The nn.Embedding layer is a simple lookup table that stores embeddings for a fixed dictionary of words. It takes two primary parameters: num_embeddings (vocabulary size) and embedding_dim (vector dimensionality). When you pass token indices to this layer, it returns the corresponding embedding vectors.

import torch
import torch.nn as nn

# Create embedding layer: 10,000 words, 300-dimensional vectors
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=300)

# Lookup embeddings for specific token indices
token_ids = torch.LongTensor([5, 142, 8, 3341])
embedded = embedding(token_ids)
print(embedded.shape)  # torch.Size([4, 300])

# Batch processing: [batch_size, sequence_length]
batch_ids = torch.LongTensor([[5, 142, 8], [3341, 22, 7]])
batch_embedded = embedding(batch_ids)
print(batch_embedded.shape)  # torch.Size([2, 3, 300])

The embedding weights are randomly initialized and updated during backpropagation. Each row in the weight matrix represents one word’s embedding vector.
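Under the hood the layer is just a weight matrix, exposed as embedding.weight. A short sketch of inspecting it, optionally re-initializing it, and confirming that gradients flow back to it:

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10000, embedding_dim=300)

# The lookup table is a [vocab_size, embedding_dim] trainable parameter
print(embedding.weight.shape)  # torch.Size([10000, 300])

# Optionally re-initialize with a smaller scale than the default N(0, 1)
with torch.no_grad():
    embedding.weight.normal_(mean=0.0, std=0.1)

# Gradients flow back to the weight matrix during backpropagation
out = embedding(torch.LongTensor([5, 142]))
out.sum().backward()
print(embedding.weight.grad.shape)  # torch.Size([10000, 300])
```

Only the rows that were actually looked up receive nonzero gradient entries, which is what makes embedding training efficient even for large vocabularies.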

Building Embeddings from Scratch

To train custom embeddings, you need a vocabulary mapping words to unique indices. Here’s a practical implementation:

from collections import Counter
import re

class Vocabulary:
    def __init__(self, min_freq=2):
        self.min_freq = min_freq
        self.word2idx = {"<PAD>": 0, "<UNK>": 1}
        self.idx2word = {0: "<PAD>", 1: "<UNK>"}
        self.word_counts = Counter()
        
    def build_vocab(self, texts):
        # Count word frequencies
        for text in texts:
            tokens = self.tokenize(text)
            self.word_counts.update(tokens)
        
        # Add words meeting minimum frequency
        idx = 2
        for word, count in self.word_counts.items():
            if count >= self.min_freq:
                self.word2idx[word] = idx
                self.idx2word[idx] = word
                idx += 1
    
    def tokenize(self, text):
        # Simple tokenization (improve with spaCy/NLTK for production)
        text = text.lower()
        tokens = re.findall(r'\b\w+\b', text)
        return tokens
    
    def encode(self, text):
        tokens = self.tokenize(text)
        return [self.word2idx.get(token, 1) for token in tokens]
    
    def decode(self, indices):
        return [self.idx2word.get(idx, "<UNK>") for idx in indices]
    
    def __len__(self):
        return len(self.word2idx)

# Usage example
texts = [
    "The quick brown fox jumps over the lazy dog",
    "The dog sleeps under the tree",
    "A quick fox runs through the forest"
]

vocab = Vocabulary(min_freq=1)
vocab.build_vocab(texts)
print(f"Vocabulary size: {len(vocab)}")

encoded = vocab.encode("The fox runs")
print(f"Encoded: {encoded}")
print(f"Decoded: {vocab.decode(encoded)}")

This vocabulary class handles tokenization, builds word-to-index mappings, and provides encoding/decoding utilities. The <PAD> token handles variable-length sequences, while <UNK> represents out-of-vocabulary words.

Training Custom Embeddings with Skip-gram

Skip-gram predicts context words given a target word. For the sentence “the quick brown fox,” with target word “brown” and window size 2, we generate pairs: (brown, the), (brown, quick), (brown, fox).

import torch
from torch.utils.data import Dataset, DataLoader

class SkipGramDataset(Dataset):
    def __init__(self, texts, vocab, window_size=2):
        self.pairs = []
        self.vocab = vocab
        
        for text in texts:
            tokens = vocab.encode(text)
            for i, target in enumerate(tokens):
                # Get context indices within window
                start = max(0, i - window_size)
                end = min(len(tokens), i + window_size + 1)
                
                for j in range(start, end):
                    if i != j:
                        self.pairs.append((target, tokens[j]))
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        target, context = self.pairs[idx]
        # Return 0-d tensors so batches come out as [batch_size], not [batch_size, 1]
        return torch.tensor(target), torch.tensor(context)

class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.target_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)
        
    def forward(self, target, context):
        target_embed = self.target_embeddings(target)
        context_embed = self.context_embeddings(context)
        
        # Compute dot product similarity
        score = torch.sum(target_embed * context_embed, dim=1)
        return score

# Training setup
dataset = SkipGramDataset(texts, vocab, window_size=2)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

model = SkipGramModel(vocab_size=len(vocab), embedding_dim=100)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()

# Training loop
for epoch in range(10):
    total_loss = 0
    for target, context in dataloader:
        optimizer.zero_grad()
        
        # Positive samples
        pos_score = model(target, context)
        pos_labels = torch.ones_like(pos_score)
        
        # Negative sampling (simplified)
        neg_context = torch.randint(0, len(vocab), context.shape)
        neg_score = model(target, neg_context)
        neg_labels = torch.zeros_like(neg_score)
        
        # Combined loss
        scores = torch.cat([pos_score, neg_score])
        labels = torch.cat([pos_labels, neg_labels])
        loss = criterion(scores, labels)
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

This implementation uses two embedding matrices: one for target words and one for context words. The model learns by maximizing similarity between actual context pairs while minimizing similarity with random negative samples.
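After training, the target embedding matrix is typically taken as the final word vectors. As a sketch (the most_similar helper below is illustrative, not part of the training code above), nearest neighbors can be queried with cosine similarity:

```python
import torch
import torch.nn.functional as F

def most_similar(model, vocab, word, top_k=3):
    # Use the target embedding matrix as the final word vectors
    weights = model.target_embeddings.weight.detach()
    idx = vocab.word2idx.get(word, 1)
    query = weights[idx].unsqueeze(0)

    # Cosine similarity between the query vector and every vocabulary word
    sims = F.cosine_similarity(query, weights)

    # Take top_k + 1 and drop the word itself (its self-similarity is 1.0)
    best = sims.topk(top_k + 1).indices.tolist()
    return [vocab.idx2word[i] for i in best if i != idx][:top_k]
```

Called as most_similar(model, vocab, "fox"), this would return the words whose learned vectors lie closest to "fox" in the embedding space.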

Using Pre-trained Embeddings

Loading pre-trained embeddings gives you high-quality representations without extensive training. Here’s how to load GloVe embeddings:

import numpy as np

def load_glove_embeddings(file_path, vocab, embedding_dim=300):
    # Initialize with small random vectors; keep the <PAD> row at zero
    embeddings = np.random.randn(len(vocab), embedding_dim) * 0.01
    embeddings[vocab.word2idx["<PAD>"]] = 0.0
    
    # Load GloVe vectors for words that appear in our vocabulary
    found = 0
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            word = parts[0]
            # Skip malformed lines (some GloVe files contain multi-token entries)
            if word in vocab.word2idx and len(parts) == embedding_dim + 1:
                embeddings[vocab.word2idx[word]] = np.array([float(x) for x in parts[1:]])
                found += 1
    
    print(f"Loaded {found}/{len(vocab)} word vectors")
    return torch.FloatTensor(embeddings)

# Create embedding layer with pre-trained weights
pretrained_embeddings = load_glove_embeddings('glove.6B.300d.txt', vocab)
embedding_layer = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=False)

# freeze=True: embeddings won't be updated during training
# freeze=False: fine-tune embeddings on your specific task
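A common fine-tuning pattern, sketched here with random stand-in weights, is to start frozen and unfreeze the embedding later via requires_grad_ once the rest of the model has settled:

```python
import torch
import torch.nn as nn

# Stand-in for loaded pre-trained weights (e.g. GloVe)
weights = torch.randn(100, 50)
embedding = nn.Embedding.from_pretrained(weights, freeze=True)

print(embedding.weight.requires_grad)  # False

# Later (e.g. after a few epochs), enable gradient updates in place
embedding.weight.requires_grad_(True)
print(embedding.weight.requires_grad)  # True
```

This two-phase schedule keeps the pre-trained vectors intact while the randomly initialized layers learn, then lets the embeddings adapt to the task.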

Practical Application: Sentiment Classifier

Here’s a complete sentiment analysis model using embeddings:

class SentimentClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        # x: [batch_size, seq_len]
        embedded = self.embedding(x)  # [batch_size, seq_len, embedding_dim]
        embedded = self.dropout(embedded)
        
        lstm_out, (hidden, cell) = self.lstm(embedded)
        # Concatenate final forward and backward hidden states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        hidden = self.dropout(hidden)
        
        output = self.fc(hidden)
        return output

# Initialize model
model = SentimentClassifier(
    vocab_size=len(vocab),
    embedding_dim=300,
    hidden_dim=128,
    num_classes=2
)

# Optionally load pre-trained embeddings
model.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=False, padding_idx=0)

Best Practices and Optimization

Embedding Dimensions: Use 50-300 dimensions for most tasks. Larger vocabularies benefit from higher dimensions, but diminishing returns occur beyond 300.

Padding Strategy: Pad sequences to the same length for efficient batching. Use padding_idx=0 in nn.Embedding to ignore padding tokens:

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

def collate_batch(batch):
    # batch: list of (text_indices, label) tuples
    texts, labels = zip(*batch)
    
    # Convert to tensors and get lengths
    texts = [torch.LongTensor(text) for text in texts]
    lengths = torch.LongTensor([len(text) for text in texts])
    
    # Pad sequences
    padded_texts = pad_sequence(texts, batch_first=True, padding_value=0)
    labels = torch.LongTensor(labels)
    
    return padded_texts, labels, lengths

# In your model's forward pass:
def forward(self, x, lengths):
    embedded = self.embedding(x)
    
    # Pack padded sequences for efficient LSTM processing
    packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
    packed_out, (hidden, cell) = self.lstm(packed)
    
    # Unpack if needed
    output, _ = pad_packed_sequence(packed_out, batch_first=True)
    return output

Memory Optimization: For large vocabularies, consider using sparse=True in the embedding layer, which uses sparse gradients and reduces memory usage during backpropagation.
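Note that sparse gradients only work with a few optimizers, such as torch.optim.SparseAdam or plain SGD; a minimal sketch:

```python
import torch
import torch.nn as nn

# sparse=True makes backward produce a sparse gradient: only the rows
# actually looked up in this batch get gradient entries
embedding = nn.Embedding(100_000, 100, sparse=True)

# Sparse gradients require a compatible optimizer
optimizer = torch.optim.SparseAdam(embedding.parameters(), lr=0.001)

out = embedding(torch.LongTensor([3, 17, 3]))
out.sum().backward()
print(embedding.weight.grad.is_sparse)  # True
optimizer.step()
```

For a 100,000-word vocabulary, this avoids materializing a dense 100,000-row gradient when a batch only touches a handful of words.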

Regularization: Apply dropout after the embedding layer to prevent overfitting. Weight decay on embedding parameters can also help, though it’s less common.
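Per-parameter-group optimizer settings let you decide whether weight decay touches the embeddings at all; a sketch with illustrative module names:

```python
import torch
import torch.nn as nn

# Illustrative model: an embedding layer plus a dense classification head
model = nn.ModuleDict({
    "embedding": nn.Embedding(1000, 50, padding_idx=0),
    "fc": nn.Linear(50, 2),
})

# Separate parameter groups: decay the dense layer, leave embeddings alone
optimizer = torch.optim.Adam([
    {"params": model["embedding"].parameters(), "weight_decay": 0.0},
    {"params": model["fc"].parameters(), "weight_decay": 1e-4},
], lr=0.001)
```

Setting the embedding group's weight_decay to a small positive value instead would apply the less common embedding regularization mentioned above.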

Word embeddings are the foundation of modern NLP. Whether training from scratch or using pre-trained weights, PyTorch’s nn.Embedding provides a flexible, efficient interface. Start with pre-trained embeddings for most tasks, and only train custom embeddings when you have domain-specific requirements or sufficient training data.
