How to Implement an RNN in PyTorch

Key Insights

  • RNNs process sequential data by maintaining hidden states that carry information across time steps, making them ideal for tasks like text classification, time series prediction, and language modeling
  • PyTorch provides both low-level primitives to build custom RNN cells and high-level modules (nn.RNN, nn.LSTM, nn.GRU) that handle the complexity of sequential processing with optimized implementations
  • Production RNN implementations require careful handling of variable-length sequences through padding and packing utilities, plus gradient clipping to prevent training instability

Understanding RNN Architecture Fundamentals

Recurrent Neural Networks differ from feedforward networks in one crucial way: they maintain an internal state that gets updated as they process each element in a sequence. This hidden state acts as the network’s memory, allowing it to capture temporal dependencies in data.

At each time step, an RNN cell takes two inputs: the current data point and the previous hidden state. It produces an updated hidden state, which serves both as the output for that step (usually fed to an output layer to make a prediction) and as the memory passed to the next time step. This architecture makes RNNs naturally suited for sequential data like text, audio, or time series.

Here’s how tensor shapes flow through a basic RNN:

import torch
import torch.nn as nn

# Define dimensions
batch_size = 32
sequence_length = 10
input_size = 50  # e.g., word embedding dimension
hidden_size = 128

# Input tensor: (batch, sequence, features)
x = torch.randn(batch_size, sequence_length, input_size)

# Initial hidden state: (batch, hidden_size)
h0 = torch.zeros(batch_size, hidden_size)

print(f"Input shape: {x.shape}")
print(f"Hidden state shape: {h0.shape}")

# After processing one time step
# Output shape: (batch, hidden_size)
# New hidden state shape: (batch, hidden_size)
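
As a quick check of the shapes listed above, here is a minimal sketch that runs a single time step through PyTorch's `nn.RNNCell`. The variable names `x_0` and `h1` are illustrative choices, not part of any API.

```python
import torch
import torch.nn as nn

batch_size, sequence_length, input_size, hidden_size = 32, 10, 50, 128

x = torch.randn(batch_size, sequence_length, input_size)
h0 = torch.zeros(batch_size, hidden_size)

# nn.RNNCell processes exactly one time step
cell = nn.RNNCell(input_size, hidden_size)

x_0 = x[:, 0, :]    # first time step of every sequence: (batch, input_size)
h1 = cell(x_0, h0)  # updated hidden state: (batch, hidden_size)

print(f"Step input shape: {x_0.shape}")
print(f"New hidden state shape: {h1.shape}")
```

Note that the full `nn.RNN` module expects its initial hidden state with an extra leading dimension, `(num_layers, batch, hidden_size)`, rather than the per-cell shape used here.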

Building a Basic RNN from Scratch

Before using PyTorch’s built-in modules, let’s implement a vanilla RNN cell to understand the mechanics. The core computation is straightforward: multiply the input by a weight matrix, multiply the previous hidden state by another weight matrix, add them together, and apply an activation function.

class SimpleRNNCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # Weight matrix for input
        self.W_ih = nn.Linear(input_size, hidden_size)
        # Weight matrix for hidden state
        self.W_hh = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, x, hidden):
        """
        x: input at current time step (batch, input_size)
        hidden: previous hidden state (batch, hidden_size)
        """
        # h_t = tanh(W_ih * x_t + W_hh * h_{t-1})
        new_hidden = torch.tanh(
            self.W_ih(x) + self.W_hh(hidden)
        )
        return new_hidden

# Process a sequence
rnn_cell = SimpleRNNCell(input_size=50, hidden_size=128)
batch_size = 32
sequence_length = 10

# Initialize hidden state
hidden = torch.zeros(batch_size, 128)

# Process each time step
for t in range(sequence_length):
    x_t = torch.randn(batch_size, 50)
    hidden = rnn_cell(x_t, hidden)
    
print(f"Final hidden state shape: {hidden.shape}")

This implementation reveals the fundamental issue with vanilla RNNs: gradients must flow backward through many matrix multiplications, leading to vanishing or exploding gradients. This is why LSTM and GRU variants were developed.
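
To make the vanishing-gradient effect concrete, the following minimal sketch measures how strongly the final hidden state depends on the very first input as the sequence gets longer. It uses `nn.RNNCell` with PyTorch's default initialization; the helper name `first_input_grad_norm` and the small dimensions are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def first_input_grad_norm(seq_len, input_size=8, hidden_size=16):
    """Norm of d(sum of final hidden state) / d(first input)."""
    cell = nn.RNNCell(input_size, hidden_size)  # tanh activation by default
    xs = [torch.randn(1, input_size, requires_grad=True) for _ in range(seq_len)]
    h = torch.zeros(1, hidden_size)
    for x_t in xs:
        h = cell(x_t, h)
    # Backpropagate from the final hidden state through every time step
    h.sum().backward()
    return xs[0].grad.norm().item()

norms = {T: first_input_grad_norm(T) for T in (2, 10, 50)}
for T, norm in norms.items():
    print(f"seq_len={T:>2}: gradient norm w.r.t. first input = {norm:.3e}")
```

With default initialization the norm typically shrinks rapidly as the sequence grows, which is the vanishing case. Larger recurrent weights can produce the opposite problem, exploding gradients.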

Using PyTorch’s Built-in RNN Modules

PyTorch provides optimized implementations that handle the sequential processing for you. Here’s how to use the three main variants:

import torch
import torch.nn as nn

# Vanilla RNN
rnn = nn.RNN(
    input_size=50,
    hidden_size=128,
    num_layers=2,
    batch_first=True,  # Input shape: (batch, seq, features)
    dropout=0.2
)

# LSTM (Long Short-Term Memory)
lstm = nn.LSTM(
    input_size=50,
    hidden_size=128,
    num_layers=2,
    batch_first=True,
    dropout=0.2
)

# GRU (Gated Recurrent Unit)
gru = nn.GRU(
    input_size=50,
    hidden_size=128,
    num_layers=2,
    batch_first=True,
    dropout=0.2
)

# Usage example
x = torch.randn(32, 10, 50)  # (batch, sequence, features)

# RNN returns output and final hidden state
rnn_out, rnn_hidden = rnn(x)
print(f"RNN output: {rnn_out.shape}")  # (32, 10, 128)

# LSTM returns output, (final hidden, final cell state)
lstm_out, (lstm_hidden, lstm_cell) = lstm(x)
print(f"LSTM output: {lstm_out.shape}")  # (32, 10, 128)

# GRU returns output and final hidden state
gru_out, gru_hidden = gru(x)
print(f"GRU output: {gru_out.shape}")  # (32, 10, 128)

For most applications, start with LSTM or GRU. LSTMs handle long-term dependencies better but have more parameters. GRUs are faster and often perform comparably.
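
The parameter trade-off can be checked directly. This sketch counts trainable parameters for the three modules using the same sizes as above; `count_parameters` is a small helper defined here, not a PyTorch function.

```python
import torch.nn as nn

def count_parameters(module):
    """Total number of elements across all parameters of a module."""
    return sum(p.numel() for p in module.parameters())

config = dict(input_size=50, hidden_size=128, num_layers=2, batch_first=True)
counts = {
    "RNN": count_parameters(nn.RNN(**config)),
    "LSTM": count_parameters(nn.LSTM(**config)),
    "GRU": count_parameters(nn.GRU(**config)),
}

for name, n in counts.items():
    print(f"{name}: {n:,} parameters")
```

Each LSTM layer holds four gate-sized weight blocks and each GRU layer three, so for identical sizes they have four and three times the parameters of the vanilla RNN.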

Complete Implementation: Sentiment Classification

Let’s build a complete sentiment classifier for movie reviews. This example demonstrates the full pipeline from data to predictions.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class SentimentRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=0.3
        )
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, text):
        # text: (batch, seq_len)
        embedded = self.dropout(self.embedding(text))
        # embedded: (batch, seq_len, embedding_dim)
        
        lstm_out, (hidden, cell) = self.lstm(embedded)
        # Use the final hidden state for classification
        # hidden: (n_layers, batch, hidden_dim)
        
        # Take the last layer's hidden state
        final_hidden = hidden[-1]  # (batch, hidden_dim)
        output = self.fc(self.dropout(final_hidden))
        return output

# Training function
def train_model(model, train_loader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0
    
    for batch in train_loader:
        texts, labels = batch
        texts, labels = texts.to(device), labels.to(device)
        
        optimizer.zero_grad()
        predictions = model(texts)
        loss = criterion(predictions, labels)
        loss.backward()
        
        # Gradient clipping (important for RNNs!)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        epoch_loss += loss.item()
    
    return epoch_loss / len(train_loader)

# Initialize model
vocab_size = 10000
model = SentimentRNN(
    vocab_size=vocab_size,
    embedding_dim=100,
    hidden_dim=256,
    output_dim=2,  # Binary classification
    n_layers=2
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop (assuming train_loader is defined)
# for epoch in range(10):
#     train_loss = train_model(model, train_loader, criterion, optimizer, device)
#     print(f'Epoch {epoch+1}, Loss: {train_loss:.4f}')
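
The training loop above assumes a `train_loader` already exists. As a stand-in for a real tokenized dataset, this sketch builds one from random token IDs and random labels so the pipeline can be exercised end to end. The data is placeholder noise (an assumption for illustration), so a model trained on it will not learn anything meaningful.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Placeholder data (assumption): 256 "reviews" of 40 token IDs each
num_samples, seq_len, vocab_size = 256, 40, 10000
texts = torch.randint(0, vocab_size, (num_samples, seq_len))  # int64 token IDs
labels = torch.randint(0, 2, (num_samples,))                  # 0 = negative, 1 = positive

train_loader = DataLoader(TensorDataset(texts, labels), batch_size=32, shuffle=True)

# Each batch unpacks as (texts, labels), matching the training function's expectations
batch_texts, batch_labels = next(iter(train_loader))
print(batch_texts.shape, batch_labels.shape)
```

The integer (`int64`) dtype matters: `nn.Embedding` requires integer indices and `nn.CrossEntropyLoss` requires integer class labels.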

Handling Variable-Length Sequences

Real-world data rarely comes in uniform lengths. PyTorch provides utilities to efficiently batch variable-length sequences:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class ImprovedSentimentRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, text, text_lengths):
        # text: (batch, padded_seq_len)
        # text_lengths: (batch,) - actual length of each sequence
        
        embedded = self.embedding(text)
        
        # Pack the padded sequence
        packed_embedded = pack_padded_sequence(
            embedded, 
            text_lengths.cpu(), 
            batch_first=True, 
            enforce_sorted=False
        )
        
        # Pass through LSTM
        packed_output, (hidden, cell) = self.lstm(packed_embedded)
        
        # Unpack if you need the outputs
        # output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)
        
        # Use final hidden state for classification
        output = self.fc(hidden[-1])
        return output

# Example usage
texts = torch.tensor([[1, 2, 3, 4, 0, 0], [5, 6, 0, 0, 0, 0]])  # Padded
lengths = torch.tensor([4, 2])  # Actual lengths

model = ImprovedSentimentRNN(vocab_size=100, embedding_dim=50, hidden_dim=128, output_dim=2)
output = model(texts, lengths)

Packing eliminates unnecessary computation on padding tokens, significantly speeding up training on datasets with variable-length sequences.
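
Packing also affects what the final hidden state means, not only speed. The sketch below, using small made-up dimensions, compares the final hidden state of a short sequence in three settings: padded without packing, padded with packing, and run alone with no padding.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

# Two sequences padded to length 6; the second has only 2 real steps
batch = torch.randn(2, 6, 8)
lengths = torch.tensor([6, 2])
batch[1, 2:] = 0.0  # zero padding

with torch.no_grad():
    # Without packing: the LSTM keeps updating its state on the padding steps
    _, (h_padded, _) = lstm(batch)

    # With packing: processing stops at each sequence's true length
    packed = pack_padded_sequence(batch, lengths, batch_first=True, enforce_sorted=False)
    _, (h_packed, _) = lstm(packed)

    # Reference: the short sequence on its own, with no padding at all
    _, (h_ref, _) = lstm(batch[1:2, :2])

packed_matches = torch.allclose(h_packed[-1, 1], h_ref[-1, 0], atol=1e-6)
padded_matches = torch.allclose(h_padded[-1, 1], h_ref[-1, 0], atol=1e-6)
print(f"Packed matches unpadded run: {packed_matches}")
print(f"Unpacked matches unpadded run: {padded_matches}")
```

Without packing, the hidden state used for classification has been pushed through the padding steps, so it no longer corresponds to the last real token.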

Best Practices and Common Pitfalls

Always clip gradients. RNNs are notorious for exploding gradients. Monitor gradient norms during training:

def train_with_monitoring(model, data_loader, optimizer, criterion):
    model.train()
    total_norm = 0
    
    for batch in data_loader:
        optimizer.zero_grad()
        outputs = model(batch['input'])
        loss = criterion(outputs, batch['target'])
        loss.backward()
        
        # clip_grad_norm_ clips in place and returns the total norm measured before clipping
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        total_norm += grad_norm.item()
        
        optimizer.step()
    
    avg_grad_norm = total_norm / len(data_loader)
    print(f"Average gradient norm: {avg_grad_norm:.4f}")
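
To see what `clip_grad_norm_` does to the gradients, this sketch deliberately inflates the loss and compares the total gradient norm before and after clipping. The tiny LSTM and the helper `total_grad_norm` are assumptions made for the demonstration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

# Scale the loss up to produce deliberately large gradients
out, _ = model(torch.randn(2, 5, 4))
(out.sum() * 1000).backward()

def total_grad_norm(parameters):
    """L2 norm over all gradients, treated as one flat vector."""
    grads = [p.grad.detach().flatten() for p in parameters if p.grad is not None]
    return torch.cat(grads).norm().item()

# Returns the pre-clipping norm and rescales the gradients in place
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
norm_after = total_grad_norm(model.parameters())

print(f"Norm before clipping: {norm_before:.4f}")
print(f"Norm after clipping:  {norm_after:.4f}")
```

Clipping rescales all gradients by a common factor, so the update direction is preserved while its magnitude is capped at `max_norm`.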

Use bidirectional RNNs for non-causal tasks. When you have access to the entire sequence (not generating text in real-time), bidirectional RNNs capture both past and future context:

lstm = nn.LSTM(
    input_size=100,
    hidden_size=128,
    num_layers=2,
    batch_first=True,
    bidirectional=True  # Output size doubles: 256
)
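
With a bidirectional LSTM, the final hidden state has one entry per layer and direction, so a classifier needs to combine the two directions of the last layer. A minimal sketch with an arbitrary batch of random inputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=100, hidden_size=128, num_layers=2,
               batch_first=True, bidirectional=True)

x = torch.randn(4, 7, 100)  # (batch, seq, features)
output, (hidden, cell) = lstm(x)

# output: (batch, seq, 2 * hidden_size)
# hidden: (num_layers * 2, batch, hidden_size)
print(output.shape, hidden.shape)

# Last layer: hidden[-2] is the forward direction, hidden[-1] the backward one
final_hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)  # (batch, 2 * hidden_size)
print(final_hidden.shape)
```

A linear layer placed on top of `final_hidden` would therefore need `2 * hidden_size` input features.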

Initialize forget gate biases to 1 for LSTMs. This helps information flow in early training:

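with torch.no_grad():
    for name, param in lstm.named_parameters():
        if 'bias' in name:
            n = param.size(0)
            # Gate order in PyTorch's bias vectors: input, forget, cell, output.
            # bias_ih and bias_hh are summed, so put 1.0 in one and 0.0 in the other.
            param[n//4:n//2].fill_(1.0 if 'bias_ih' in name else 0.0)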

RNNs remain powerful tools for sequential data despite the rise of Transformers. At inference time they carry only a fixed-size hidden state from step to step, so memory use does not grow with sequence length, and they are perfectly adequate for many production applications. Master these fundamentals, and you’ll have a versatile tool for sequential modeling tasks.
