How to Use DataLoader in PyTorch


Key Insights

  • DataLoader handles batching, shuffling, and parallel data loading automatically, eliminating manual iteration code and significantly improving training performance through multiprocessing.
  • The num_workers parameter is critical for performance: start with 0 for debugging, then try 4 to 8 workers for production (generally no more than the number of available CPU cores), watching for diminishing returns and memory overhead.
  • Custom collate_fn functions are essential when working with variable-length sequences or non-standard data structures that don’t fit into uniform tensors.

Introduction to DataLoader and Why It Matters

PyTorch’s DataLoader is the bridge between your raw data and your model’s training loop. While you could iterate through your dataset manually, batch samples yourself, and implement shuffling logic, DataLoader automates these tasks and adds critical optimizations that directly impact training speed.

DataLoader solves three fundamental problems. First, it handles batching—grouping individual samples into tensors that GPUs can process efficiently. Second, it manages shuffling to prevent your model from learning spurious patterns based on data order. Third, it enables parallel data loading through multiprocessing, ensuring your GPU doesn’t sit idle waiting for the CPU to prepare the next batch.
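To make the first two jobs concrete, here is a rough sketch of shuffling and batching done by hand over in-memory tensors, next to the equivalent DataLoader call. The toy tensors and the `manual_batches` helper are hypothetical; the real DataLoader also collates arbitrary sample types and manages worker processes, which this sketch leaves out.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy in-memory data: 10 samples with 3 features each, plus integer labels.
features = torch.randn(10, 3)
targets = torch.arange(10)

def manual_batches(x, y, batch_size):
    """Hypothetical helper: shuffle once, then slice into batches by hand."""
    perm = torch.randperm(len(x))  # a fresh random order for this epoch
    for start in range(0, len(x), batch_size):
        idx = perm[start:start + batch_size]
        yield x[idx], y[idx]

# The same job expressed with DataLoader.
loader = DataLoader(TensorDataset(features, targets), batch_size=4, shuffle=True)

manual_sizes = [xb.shape[0] for xb, _ in manual_batches(features, targets, 4)]
loader_sizes = [xb.shape[0] for xb, _ in loader]
print(manual_sizes, loader_sizes)  # [4, 4, 2] [4, 4, 2]
```

Both produce two full batches and one partial batch; the DataLoader version simply hides the permutation and slicing.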

You need DataLoader for virtually any non-trivial PyTorch project. Skip it only when prototyping with tiny datasets that fit entirely in memory or when you need complete manual control over data iteration for research purposes.

Creating a Basic Dataset and DataLoader

For map-style datasets, the most common kind, PyTorch expects you to wrap your data in a Dataset class that implements two methods: __len__ returns the total number of samples, and __getitem__ returns a single sample given an index. This abstraction lets DataLoader handle the iteration logic while you focus on data access.

Here’s a custom Dataset for image classification:

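import os

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ImageDataset(Dataset):
    def __init__(self, image_dir, labels, transform=None):
        self.image_dir = image_dir
        self.labels = labels  # one label per file, in sorted filename order
        self.transform = transform
        self.image_files = sorted(os.listdir(image_dir))
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Create dataset and dataloader.
# Resize + ToTensor turn every PIL image into a tensor of the same shape,
# which the default collate function needs in order to stack a batch
# (224x224 is just an example size).
dataset = ImageDataset(
    image_dir='./data/images',
    labels=[0, 1, 0, 1, 1],  # Example labels: assumes the folder holds exactly 5 images
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
)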

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True
)

# Iterate through batches
for images, labels in dataloader:
    print(f"Batch shape: {images.shape}, Labels: {labels.shape}")

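The DataLoader takes your Dataset and handles the rest. Each iteration yields one batch: the default collate function stacks the per-sample tensors along a new batch dimension and converts integer labels into a tensor. It cannot stack PIL images or images of different sizes, so pass a transform that resizes each image and converts it to a tensor (for example, torchvision’s transforms.Resize followed by transforms.ToTensor).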
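Because the image example needs image files on disk, here is a small in-memory sketch of what default collation does with (tensor, int) samples. `RandomVectorDataset` is a hypothetical stand-in for any Dataset whose __getitem__ already returns tensors.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class RandomVectorDataset(Dataset):
    """Hypothetical dataset: each sample is a 3x8x8 tensor and an int label."""

    def __init__(self, n):
        self.data = torch.randn(n, 3, 8, 8)
        self.labels = [i % 2 for i in range(n)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

loader = DataLoader(RandomVectorDataset(10), batch_size=4)
images, labels = next(iter(loader))
# Default collation stacked the tensors and turned the Python ints into a tensor.
print(images.shape, labels.shape, labels.dtype)
# torch.Size([4, 3, 8, 8]) torch.Size([4]) torch.int64
```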

Key DataLoader Parameters Explained

Understanding DataLoader parameters is crucial for both correctness and performance.

batch_size determines how many samples are processed together. Larger batches improve GPU utilization but require more memory. Start with 32 or 64, then increase until you hit memory limits. Batch size also affects training dynamics—smaller batches add noise that can help escape local minima but may require learning rate adjustments.

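shuffle=True reshuffles the sample order every epoch, preventing the model from learning patterns tied to the data sequence. Always shuffle training data. Validation and test data don’t need shuffling: aggregate metrics don’t depend on order, and a fixed order keeps evaluation deterministic and makes it easy to match predictions back to samples.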
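A quick way to see the difference, sketched with a toy TensorDataset of sample ids (an illustrative assumption): the unshuffled loader yields the same order on every pass, while the shuffled one draws a new permutation each time it is iterated.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))  # each sample is just its own id

train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset, batch_size=4, shuffle=False)

def epoch_order(loader):
    """Concatenate every batch of one pass into a flat list of sample ids."""
    return torch.cat([batch for (batch,) in loader]).tolist()

# Validation order is identical on every pass.
print(epoch_order(val_loader))  # [0, 1, 2, 3, 4, 5, 6, 7]
# Training order is reshuffled each time the loader is iterated.
print(epoch_order(train_loader))
print(epoch_order(train_loader))
```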

num_workers controls parallel data loading. Set it to 0 to load data in the main process (useful for debugging), or use around 4-8 worker processes for production. More workers can load data faster, but each one is a separate process that adds memory overhead. The optimal number depends on your CPU core count, disk speed, and preprocessing complexity, so it is worth measuring.
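Rather than guessing, you can time one pass over the data for a few candidate values. This sketch uses a hypothetical `SlowDataset` whose `time.sleep` call stands in for real decoding or augmentation cost, so the absolute numbers mean nothing; the comparison between settings is the point, and results depend heavily on your machine and pipeline.

```python
import time

import torch
from torch.utils.data import DataLoader, Dataset

class SlowDataset(Dataset):
    """Hypothetical dataset that simulates per-sample preprocessing cost."""

    def __init__(self, n=256, delay=0.002):
        self.n = n
        self.delay = delay

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        time.sleep(self.delay)  # stand-in for decoding or augmentation work
        return torch.full((16,), float(idx))

def time_one_epoch(num_workers, batch_size=32):
    """Return wall-clock seconds for one full pass over SlowDataset."""
    loader = DataLoader(SlowDataset(), batch_size=batch_size, num_workers=num_workers)
    start = time.perf_counter()
    for _ in loader:
        pass
    return time.perf_counter() - start

if __name__ == "__main__":
    for workers in (0, 2, 4):
        print(f"num_workers={workers}: {time_one_epoch(workers):.2f}s")
```

Note that each timing includes worker startup, which persistent_workers=True would amortize across epochs.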

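pin_memory=True copies each batch into page-locked (pinned) host memory, which speeds up CPU-to-GPU transfers and allows asynchronous copies with .to(device, non_blocking=True). Use it when training on a CUDA GPU unless host memory is tight; it brings no benefit for CPU-only training.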
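A minimal sketch of pairing pin_memory with non-blocking transfers, assuming a toy TensorDataset. Pinning is enabled only when CUDA is available, so the same code also runs on a CPU-only machine.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"

dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
loader = DataLoader(
    dataset,
    batch_size=16,
    pin_memory=use_cuda,  # pinning only pays off when batches are headed to a GPU
)

seen = 0
for inputs, targets in loader:
    # From pinned memory, non_blocking=True lets the host-to-GPU copy run
    # asynchronously; for a CPU device, .to() just returns the same tensor.
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    seen += inputs.shape[0]

print(seen, inputs.device.type)
```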

# Development configuration - easy debugging
dev_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,  # Single-threaded for easier debugging
    pin_memory=False
)

# Production configuration - optimized performance
prod_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,  # Parallel loading
    pin_memory=True,  # Faster GPU transfer
    persistent_workers=True  # Keep workers alive between epochs
)

Working with Built-in PyTorch Datasets

PyTorch provides pre-built datasets through torchvision, torchtext, and torchaudio. These handle downloading, caching, and standard preprocessing, letting you focus on model development.

from torchvision import datasets, transforms

# Define preprocessing pipeline
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    )
])

# Load CIFAR-10
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

# Inspect a batch
images, labels = next(iter(train_loader))
print(f"Images: {images.shape}")  # [128, 3, 32, 32]
print(f"Labels: {labels.shape}")  # [128]

Transforms are applied on-the-fly during data loading, not upfront. This saves memory and enables random augmentations that differ each epoch.
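A sketch of why on-the-fly transforms give fresh augmentations: the hypothetical `FlipAugmentedDataset` below stores each image once and tosses a coin inside __getitem__, so repeated accesses to the same index can return different versions. It uses plain torch.flip rather than torchvision to stay self-contained.

```python
import torch
from torch.utils.data import Dataset

class FlipAugmentedDataset(Dataset):
    """Hypothetical dataset that applies a random horizontal flip in __getitem__."""

    def __init__(self, images, p=0.5):
        self.images = images  # stored once, unaugmented
        self.p = p

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        # The coin is tossed on every access, so each epoch can see a new variant.
        if torch.rand(1).item() < self.p:
            image = torch.flip(image, dims=[-1])  # flip along the width axis
        return image

torch.manual_seed(0)
images = torch.arange(6, dtype=torch.float32).reshape(1, 1, 2, 3)  # one 1x2x3 "image"
dataset = FlipAugmentedDataset(images)
variants = {tuple(dataset[0].flatten().tolist()) for _ in range(20)}
print(len(variants))  # number of distinct versions seen across 20 accesses
```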

Advanced DataLoader Techniques

Real-world data often doesn’t fit neatly into uniform tensors. Variable-length sequences, nested structures, or imbalanced classes require custom handling.
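For the imbalanced-class case, one option that ships with torch.utils.data is WeightedRandomSampler, which draws indices in proportion to per-sample weights. The sketch below uses inverse class frequency as the weight, which is one common heuristic rather than the only choice; the toy labels are made up for illustration. A sampler replaces shuffle=True, and DataLoader rejects setting both.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy imbalanced labels: 90 samples of class 0, 10 of class 1.
labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
features = torch.randn(100, 4)
dataset = TensorDataset(features, labels)

# Weight every sample by the inverse frequency of its class.
class_counts = torch.bincount(labels)  # tensor([90, 10])
sample_weights = 1.0 / class_counts[labels].float()

generator = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=1000,   # draws per epoch; may differ from len(dataset)
    replacement=True,   # lets minority samples be drawn more than once
    generator=generator,
)

# The sampler decides the order, so shuffle is left at its default (False).
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

drawn = torch.cat([y for _, y in loader])
print(torch.bincount(drawn))  # roughly balanced between the two classes
```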

Custom collate functions control how individual samples combine into batches. This is essential for NLP tasks with variable-length sequences:

from torch.nn.utils.rnn import pad_sequence

class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts  # List of token sequences
        self.labels = labels
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        return torch.tensor(self.texts[idx]), self.labels[idx]

def collate_fn(batch):
    # batch is a list of (text, label) tuples
    texts, labels = zip(*batch)
    
    # Pad sequences to same length
    texts_padded = pad_sequence(texts, batch_first=True, padding_value=0)
    labels = torch.tensor(labels)
    
    return texts_padded, labels

# Example usage
texts = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]  # Variable lengths
labels = [0, 1, 0]

dataset = TextDataset(texts, labels)
loader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=collate_fn
)

for batch_texts, batch_labels in loader:
    print(f"Padded texts shape: {batch_texts.shape}")
    print(f"Labels: {batch_labels}")
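Recurrent models often also need each sequence’s true length so they can skip the padding. One way to provide it is to return a lengths tensor from the collate function and pass it to pack_padded_sequence. This is an extension of the collate_fn above rather than part of it, and the float stand-in for embeddings is purely illustrative.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

def collate_with_lengths(batch):
    """Pad token sequences and also return each sequence's true length."""
    texts, labels = zip(*batch)
    lengths = torch.tensor([len(t) for t in texts])
    texts_padded = pad_sequence(list(texts), batch_first=True, padding_value=0)
    return texts_padded, lengths, torch.tensor(labels)

batch = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
]
padded, lengths, labels = collate_with_lengths(batch)
print(padded.tolist())   # [[1, 2, 3], [4, 5, 0]]
print(lengths.tolist())  # [3, 2]

# An RNN can consume the packed form; real code would embed the tokens first.
embedded = padded.unsqueeze(-1).float()  # (batch, seq_len, 1) stand-in for embeddings
packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
```

Passing collate_fn=collate_with_lengths to DataLoader would then yield (padded, lengths, labels) triples.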

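drop_last=True discards the final incomplete batch when your dataset size isn’t divisible by batch_size. Use it when batch normalization or other operations need consistent batch sizes; for example, BatchNorm1d raises an error in training mode if the final batch contains a single sample.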
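The effect on the number of batches is easy to check with a toy dataset of 10 samples and batch_size=3 (sizes chosen for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(10, 4))  # 10 samples

loader_keep = DataLoader(dataset, batch_size=3)                  # drop_last=False is the default
loader_drop = DataLoader(dataset, batch_size=3, drop_last=True)

print(len(loader_keep), [b.shape[0] for (b,) in loader_keep])  # 4 [3, 3, 3, 1]
print(len(loader_drop), [b.shape[0] for (b,) in loader_drop])  # 3 [3, 3, 3]
```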

Common Pitfalls and Best Practices

num_workers debugging: Multiprocessing errors are cryptic. Always test with num_workers=0 first. Once your code works single-threaded, gradually increase workers.
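One frequent cause of those errors: when worker processes are started with the spawn method (the default on Windows and macOS), each worker re-imports your script, so code that creates and iterates a DataLoader with num_workers > 0 should sit behind an `if __name__ == "__main__":` guard. A minimal sketch, where `main` is a hypothetical entry point:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Module-level definitions are fine: spawned workers re-import them harmlessly.
dataset = TensorDataset(torch.arange(100).float())

def main(num_workers=2):
    """Hypothetical entry point; returns how many samples were loaded."""
    loader = DataLoader(dataset, batch_size=10, num_workers=num_workers)
    total = 0
    for (batch,) in loader:
        total += batch.shape[0]
    return total

if __name__ == "__main__":
    # Only the original process reaches this line, so workers never
    # start a second copy of the loading loop when they import the module.
    print(main())  # 100
```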

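Memory leaks: Worker memory can grow steadily across epochs. A common cause is a Dataset that holds large Python lists or dicts of objects: each access in a worker updates those objects’ reference counts, which defeats copy-on-write and gradually duplicates the memory in every worker. Storing such data in NumPy arrays or tensors avoids this. Monitor memory usage, and consider persistent_workers=False if you still see growth across epochs.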
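A sketch of the array-based mitigation, using a hypothetical `PathDataset` that keeps file paths in a NumPy string array instead of a Python list. Whether this helps in your pipeline depends on what the Dataset actually stores, so treat it as something to measure.

```python
import numpy as np
from torch.utils.data import Dataset

class PathDataset(Dataset):
    """Hypothetical dataset that stores file paths in a NumPy array.

    A Python list of str objects has its reference counts touched on every
    access, which can gradually copy the list's memory pages into each worker.
    A NumPy array of fixed-width strings is one buffer with no per-item
    Python objects to refcount.
    """

    def __init__(self, paths):
        self.paths = np.array(paths)  # fixed-width unicode dtype, one contiguous buffer

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return str(self.paths[idx])  # convert back to a Python str on demand

dataset = PathDataset(["img_000.png", "img_001.png", "img_002.png"])
print(len(dataset), dataset[1])  # 3 img_001.png
```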

Reproducibility: Repeatable runs with shuffling and multiple workers need seeds for both the sampler and each worker’s random number generators:

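import random

import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    # Derive this worker's NumPy and Python seeds from its PyTorch seed
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)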

g = torch.Generator()
g.manual_seed(42)

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    worker_init_fn=seed_worker,
    generator=g
)

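The seeded generator fixes the shuffling order (the sampler runs in the main process), while seed_worker gives NumPy and Python’s random module a deterministic seed in each worker, so random augmentations computed inside workers are reproducible too.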
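To check the setup, you can build two loaders from the same seed and compare the order they produce. This sketch repeats seed_worker so it runs on its own; `make_loader` and `first_epoch` are hypothetical helpers, and num_workers defaults to 0 to keep the check lightweight.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def make_loader(seed, num_workers=0):
    """Build a shuffled loader whose order depends only on `seed`."""
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(
        TensorDataset(torch.arange(20)),
        batch_size=5,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )

def first_epoch(loader):
    """Flatten one pass over the loader into a list of sample ids."""
    return torch.cat([b for (b,) in loader]).tolist()

# Two loaders built with the same seed produce the same order.
print(first_epoch(make_loader(42)) == first_epoch(make_loader(42)))  # True
```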

Complete Training Loop Example

Here’s how DataLoader integrates into a full training pipeline:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def main():
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.MNIST(root='./data', train=True,
                                   download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                              num_workers=4,
                              pin_memory=(device.type == 'cuda'))

    # Model
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28*28, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 5
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for batch_idx, (images, labels) in enumerate(train_loader):
            # Move the batch to the selected device (GPU if available)
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], "
                      f"Step [{batch_idx}/{len(train_loader)}], "
                      f"Loss: {loss.item():.4f}")

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")


# The guard keeps worker processes (num_workers > 0) from re-running training
# on platforms that start them with spawn, such as Windows and macOS.
if __name__ == '__main__':
    main()

DataLoader handles all data management while you focus on model logic. The pattern is always the same: iterate through the loader, move data to device, compute loss, backpropagate. This separation of concerns makes PyTorch code clean and maintainable.
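The same loop shape carries over to evaluation, with three changes: no shuffling, no gradient updates, and the model in eval mode. A minimal sketch, assuming a classification model; `evaluate` is a hypothetical helper, and the toy check uses nn.Identity so predictions are simply the argmax of the inputs.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()  # no gradients are needed during evaluation
def evaluate(model, loader, device):
    """Return classification accuracy of `model` over every batch in `loader`."""
    model.eval()  # switch layers such as dropout and batchnorm to inference behavior
    correct, total = 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)
    return correct / total

# Toy check: the "model" returns its input, so argmax picks the larger column.
inputs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 0.5]])
targets = torch.tensor([0, 1, 0])
val_loader = DataLoader(TensorDataset(inputs, targets), batch_size=2, shuffle=False)
print(evaluate(nn.Identity(), val_loader, torch.device("cpu")))  # 1.0
```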

Master DataLoader’s parameters and patterns, and you’ll build efficient, scalable training pipelines for any deep learning task.
