How to Implement a GRU in PyTorch
Key Insights
- GRUs use two gates (update and reset) instead of LSTM’s three, making them computationally cheaper while maintaining comparable performance on many sequence tasks
- Implementing a GRU from scratch reveals the elegant mathematics behind gating mechanisms: the update gate controls how much past information to retain, while the reset gate determines how much past state to forget when computing new content
- PyTorch’s built-in nn.GRU is 5-10x faster than naive implementations due to cuDNN optimization, but understanding the underlying mechanics helps debug training issues and customize architectures
Introduction to GRU Architecture
Gated Recurrent Units (GRUs) solve the vanishing gradient problem that plagues vanilla RNNs by introducing gating mechanisms that control information flow. Proposed by Cho et al. in 2014, GRUs are a streamlined alternative to LSTMs that achieve similar performance with fewer parameters.
The key difference: LSTMs use three gates (input, forget, output) and maintain both a cell state and hidden state, while GRUs use two gates (update and reset) with only a hidden state. This makes GRUs approximately 25% faster to train and easier to tune.
A GRU cell performs four main operations at each time step:

```python
# GRU cell structure (conceptual)
#
# Inputs: x_t (current input), h_{t-1} (previous hidden state)
#
# 1. Reset Gate:       r_t = σ(W_r @ x_t + U_r @ h_{t-1} + b_r)
#    Controls how much past information to forget
#
# 2. Update Gate:      z_t = σ(W_z @ x_t + U_z @ h_{t-1} + b_z)
#    Controls how much to update the hidden state
#
# 3. Candidate Hidden: h̃_t = tanh(W_h @ x_t + U_h @ (r_t ⊙ h_{t-1}) + b_h)
#    Computes new candidate values
#
# 4. Final Hidden:     h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t
#    Blends old and new hidden states
#
# Output: h_t (new hidden state)
```
The reset gate allows the model to drop irrelevant past information, while the update gate decides how much of the new candidate state to use versus keeping the old state. This is more intuitive than LSTM’s separate input and forget gates.
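To make the interpolation step concrete, here is a quick numeric check with arbitrarily chosen state values: when z_t = 0 the cell keeps the old state untouched, and when z_t = 1 it replaces it entirely with the candidate.

```python
import torch

h_prev = torch.tensor([1.0, -2.0, 0.5])  # previous hidden state (arbitrary values)
h_cand = torch.tensor([0.0,  3.0, 1.0])  # candidate hidden state (arbitrary values)

def blend(z, h_prev, h_cand):
    # Final-hidden update: h_t = (1 - z) * h_prev + z * h_cand
    return (1 - z) * h_prev + z * h_cand

print(blend(torch.tensor(0.0), h_prev, h_cand))  # keeps h_prev exactly
print(blend(torch.tensor(1.0), h_prev, h_cand))  # takes h_cand exactly
print(blend(torch.tensor(0.5), h_prev, h_cand))  # halfway: tensor([0.5000, 0.5000, 0.7500])
```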
Building a GRU Cell from Scratch
Understanding the mathematics requires implementing it manually. Here’s a complete GRU cell in PyTorch:
```python
import torch
import torch.nn as nn


class GRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Reset gate parameters (the U_* layers skip the bias so each
        # gate has a single bias term, matching the equations above)
        self.W_r = nn.Linear(input_size, hidden_size)
        self.U_r = nn.Linear(hidden_size, hidden_size, bias=False)

        # Update gate parameters
        self.W_z = nn.Linear(input_size, hidden_size)
        self.U_z = nn.Linear(hidden_size, hidden_size, bias=False)

        # Candidate hidden state parameters
        self.W_h = nn.Linear(input_size, hidden_size)
        self.U_h = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x, h_prev):
        """
        Args:
            x: input tensor of shape (batch_size, input_size)
            h_prev: previous hidden state of shape (batch_size, hidden_size)

        Returns:
            h: new hidden state of shape (batch_size, hidden_size)
        """
        # Reset gate
        r = torch.sigmoid(self.W_r(x) + self.U_r(h_prev))
        # Update gate
        z = torch.sigmoid(self.W_z(x) + self.U_z(h_prev))
        # Candidate hidden state (reset gate applied to previous hidden)
        h_candidate = torch.tanh(self.W_h(x) + self.U_h(r * h_prev))
        # Final hidden state (interpolation between old and new)
        h = (1 - z) * h_prev + z * h_candidate
        return h
```
This implementation makes the gating mechanism explicit. The reset gate r modulates how much of the previous hidden state influences the candidate, while the update gate z acts as a learned interpolation coefficient between old and new states.
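PyTorch also ships a fused single-step cell, nn.GRUCell, with the same (x, h_prev) → h interface, which makes a convenient reference for checking that a hand-rolled cell produces the expected shapes:

```python
import torch
import torch.nn as nn

cell = nn.GRUCell(input_size=10, hidden_size=20)

x = torch.randn(5, 10)       # (batch_size, input_size)
h_prev = torch.zeros(5, 20)  # (batch_size, hidden_size)

h = cell(x, h_prev)          # one time step
print(h.shape)               # torch.Size([5, 20])
```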
Creating a Multi-Layer GRU Network
A single GRU cell processes one time step. For sequence processing, we need to iterate through time and optionally stack multiple layers:
```python
class GRUNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(GRUNetwork, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Create a GRU cell for each layer
        self.cells = nn.ModuleList()
        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size
            self.cells.append(GRUCell(layer_input_size, hidden_size))

        # Output projection
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, h_0=None):
        """
        Args:
            x: input tensor of shape (batch_size, seq_len, input_size)
            h_0: initial hidden state of shape (num_layers, batch_size, hidden_size)

        Returns:
            output: predictions of shape (batch_size, seq_len, output_size)
            h_n: final hidden states of shape (num_layers, batch_size, hidden_size)
        """
        batch_size, seq_len, _ = x.size()

        # Initialize hidden states if not provided
        if h_0 is None:
            h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size,
                              device=x.device)

        # Process sequence
        outputs = []
        h_current = h_0
        for t in range(seq_len):
            x_t = x[:, t, :]  # (batch_size, input_size)
            h_next = []
            # Process through layers
            for layer in range(self.num_layers):
                h_prev = h_current[layer]
                h_new = self.cells[layer](x_t, h_prev)
                h_next.append(h_new)
                x_t = h_new  # Input to next layer
            h_current = torch.stack(h_next)
            outputs.append(x_t)

        # Stack outputs and project
        output_seq = torch.stack(outputs, dim=1)  # (batch_size, seq_len, hidden_size)
        output = self.fc(output_seq)
        return output, h_current
```
This implementation explicitly handles the temporal loop and layer stacking, giving you complete control over the forward pass.
Using PyTorch’s Built-in GRU
For production code, use PyTorch’s optimized nn.GRU:
```python
# Built-in GRU with the same configuration
builtin_gru = nn.GRU(
    input_size=10,
    hidden_size=128,
    num_layers=2,
    batch_first=True,   # Input shape: (batch, seq, feature)
    dropout=0.2,        # Dropout between layers
    bidirectional=False
)

# Example usage
batch_size, seq_len, input_size = 32, 50, 10
x = torch.randn(batch_size, seq_len, input_size)

# Forward pass
output, h_n = builtin_gru(x)
print(f"Output shape: {output.shape}")     # (32, 50, 128)
print(f"Final hidden shape: {h_n.shape}")  # (2, 32, 128)

# Bidirectional GRU doubles the output feature size
bidir_gru = nn.GRU(input_size=10, hidden_size=128, num_layers=2,
                   batch_first=True, bidirectional=True)
output_bidir, h_n_bidir = bidir_gru(x)
print(f"Bidirectional output shape: {output_bidir.shape}")  # (32, 50, 256)
print(f"Bidirectional hidden shape: {h_n_bidir.shape}")     # (4, 32, 128)
```
The built-in version uses cuDNN when available, providing massive speedups. The bidirectional=True option processes sequences both forward and backward, concatenating the outputs.
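One detail that trips people up: with bidirectional=True, h_n stacks directions into the layer dimension. A common pattern, sketched here on random inputs, is to reshape it to (num_layers, num_directions, batch, hidden) and pick out the last layer's forward and backward states:

```python
import torch
import torch.nn as nn

num_layers, hidden_size = 2, 128
gru = nn.GRU(input_size=10, hidden_size=hidden_size, num_layers=num_layers,
             batch_first=True, bidirectional=True)

x = torch.randn(32, 50, 10)
output, h_n = gru(x)  # h_n: (num_layers * 2, batch, hidden)

# Separate the direction axis from the layer axis
h_n = h_n.view(num_layers, 2, 32, hidden_size)
fwd_last = h_n[-1, 0]  # last layer, forward direction:  (32, 128)
bwd_last = h_n[-1, 1]  # last layer, backward direction: (32, 128)

# Concatenate to match the feature layout of `output`
combined = torch.cat([fwd_last, bwd_last], dim=1)  # (32, 256)
print(combined.shape)
```

The forward final state equals the last time step of the first half of `output`, and the backward final state equals the first time step of the second half, which is a handy sanity check.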
Training a GRU on a Real Dataset
Here’s a complete sentiment analysis example using IMDB reviews:
```python
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class SentimentGRU(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers):
        super(SentimentGRU, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(embedding_dim, hidden_size, num_layers,
                          batch_first=True, dropout=0.3)
        self.fc = nn.Linear(hidden_size, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # x: (batch_size, seq_len)
        embedded = self.dropout(self.embedding(x))  # (batch_size, seq_len, embed_dim)
        _, h_n = self.gru(embedded)  # h_n: (num_layers, batch_size, hidden_size)
        # Use the final hidden state from the last layer
        final_hidden = h_n[-1]  # (batch_size, hidden_size)
        output = self.fc(self.dropout(final_hidden))
        # squeeze(-1) rather than a bare squeeze(): squeeze() would also
        # drop the batch dimension when batch_size == 1
        return output.squeeze(-1)


# Training setup
vocab_size = 10000
model = SentimentGRU(vocab_size=vocab_size, embedding_dim=128,
                     hidden_size=256, num_layers=2)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


# Training loop
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for batch_texts, batch_labels in dataloader:
        batch_texts = batch_texts.to(device)
        batch_labels = batch_labels.to(device).float()

        optimizer.zero_grad()
        outputs = model(batch_texts)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        # Gradient clipping (important for RNNs!) -- must come after
        # backward(), once gradients exist, and before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)


# Example training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Dummy data for demonstration
train_data = TensorDataset(
    torch.randint(0, vocab_size, (1000, 200)),  # 1000 sequences of length 200
    torch.randint(0, 2, (1000,))                # Binary labels
)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

for epoch in range(10):
    loss = train_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
```
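Because the model returns raw logits (to pair with BCEWithLogitsLoss), evaluation needs an explicit sigmoid before thresholding. A minimal sketch of turning logits into class predictions, using made-up logit values:

```python
import torch

logits = torch.tensor([0.8, -1.2, 0.1])  # hypothetical model outputs
probs = torch.sigmoid(logits)            # probabilities in (0, 1)
preds = (probs > 0.5).long()             # threshold at 0.5
print(preds)  # tensor([1, 0, 1])
```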
Performance Tips and Best Practices
Always clip gradients when training GRUs. Without clipping, exploding gradients can destabilize training:
```python
# Clip by norm (recommended)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Or clip by value
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
```
Initialize hidden states properly. For most tasks, zero initialization works fine, but for some applications, learnable initial states help:
```python
class GRUWithLearnableInit(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        # Learnable initial hidden state, broadcast across the batch
        self.h0 = nn.Parameter(torch.zeros(num_layers, 1, hidden_size))

    def forward(self, x):
        batch_size = x.size(0)
        h0 = self.h0.expand(-1, batch_size, -1).contiguous()
        return self.gru(x, h0)
```
GRU vs LSTM decision: Use GRUs when training time matters and your sequences aren’t extremely long (< 500 steps). GRUs train 20-30% faster with 25% fewer parameters. Use LSTMs for very long sequences or when you need the separate cell state for interpretability.
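The parameter claim is easy to verify: per layer, a GRU carries three weight/bias groups to an LSTM's four, so for matching configurations the ratio comes out to exactly 3/4:

```python
import torch.nn as nn

gru = nn.GRU(input_size=128, hidden_size=256, num_layers=2, batch_first=True)
lstm = nn.LSTM(input_size=128, hidden_size=256, num_layers=2, batch_first=True)

gru_params = sum(p.numel() for p in gru.parameters())
lstm_params = sum(p.numel() for p in lstm.parameters())
print(f"GRU/LSTM parameter ratio: {gru_params / lstm_params:.2f}")  # 0.75
```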
Benchmark custom vs built-in:
```python
import time

custom_model = GRUNetwork(input_size=100, hidden_size=256, num_layers=2, output_size=10)
builtin_model = nn.GRU(input_size=100, hidden_size=256, num_layers=2, batch_first=True)

x = torch.randn(32, 100, 100)

with torch.no_grad():  # inference only; skip autograd bookkeeping
    # Custom implementation
    start = time.time()
    for _ in range(100):
        custom_model(x)
    custom_time = time.time() - start

    # Built-in implementation
    start = time.time()
    for _ in range(100):
        builtin_model(x)
    builtin_time = time.time() - start

print(f"Custom: {custom_time:.3f}s, Built-in: {builtin_time:.3f}s")
print(f"Speedup: {custom_time/builtin_time:.1f}x")
```
On GPU, expect 5-10x speedup from the built-in version. Build custom implementations for learning and experimentation, but use nn.GRU for production.