How to Create Custom Datasets in PyTorch


Key Insights

  • Custom PyTorch datasets require implementing just two methods—__len__() and __getitem__()—but the design choices you make around caching, lazy loading, and transforms significantly impact training performance.
  • Always separate data loading logic from transformation logic by accepting transforms as constructor arguments, enabling you to reuse the same dataset class for training and validation with different augmentation pipelines.
  • Use custom collate functions with DataLoader when working with variable-length sequences or complex data structures that don’t naturally stack into tensors—the default collate only handles uniformly shaped samples.

Understanding the Dataset Abstract Class

PyTorch’s torch.utils.data.Dataset is an abstract class that serves as the foundation for all dataset implementations. Whether you’re loading images, text, audio, or multimodal data, you’ll need to understand this interface.

The contract is simple: implement __len__() to return the total number of samples, and __getitem__(idx) to return a single sample at the given index. PyTorch’s DataLoader uses these methods to iterate through your data, handle batching, and enable multiprocessing.

Here’s the minimal skeleton:

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data_source):
        self.data_source = data_source
    
    def __len__(self):
        return len(self.data_source)
    
    def __getitem__(self, idx):
        # Load and return a single sample
        sample = self.data_source[idx]
        return sample

This structure might seem trivial, but it’s powerful. PyTorch doesn’t care how you implement these methods—you could load from disk, query a database, generate data on-the-fly, or pull from an API. The abstraction gives you complete control.
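To make the "generate data on-the-fly" point concrete, here is a sketch of a dataset that needs no files at all—it synthesizes regression pairs on demand. The linear relationship, noise level, and per-index seeding are arbitrary illustration choices, not anything PyTorch requires:

```python
import torch
from torch.utils.data import Dataset

class SyntheticRegressionDataset(Dataset):
    """Generates (x, y) pairs on demand: y = 3x + noise."""
    def __init__(self, num_samples=1000, noise_std=0.1):
        self.num_samples = num_samples
        self.noise_std = noise_std

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Seed per index so each "sample" is stable across epochs
        g = torch.Generator().manual_seed(idx)
        x = torch.rand(1, generator=g)
        y = 3 * x + self.noise_std * torch.randn(1, generator=g)
        return x, y
```

Because the per-index generator makes samples deterministic, the same index returns the same pair every epoch—useful when you want reproducibility without storing anything.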

Building a Simple Custom Dataset

Let’s create a practical image classification dataset that loads images from a folder structure where subdirectories represent class labels:

import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class ImageFolderDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        
        # Build file list
        self.samples = []
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            
            for img_name in os.listdir(class_dir):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    img_path = os.path.join(class_dir, img_name)
                    self.samples.append((img_path, self.class_to_idx[class_name]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

This implementation builds the complete file list during initialization. For small to medium datasets (under 100K images), this approach works well. The key decision here is when to open and process files—we do it in __getitem__() for memory efficiency.
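At the other extreme, for tiny datasets you can trade memory for speed and materialize every sample in __init__() instead. A hedged sketch of that pattern, where load_fn is a hypothetical callable standing in for your actual decode step (e.g. opening and transforming an image):

```python
import torch
from torch.utils.data import Dataset

class PreloadedDataset(Dataset):
    """Eagerly materializes every sample at construction time."""
    def __init__(self, entries, load_fn):
        # Pay the full I/O and decode cost once, up front
        self.samples = [load_fn(e) for e in entries]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Pure in-memory lookup; no per-epoch disk access
        return self.samples[idx]
```

This only makes sense when the decoded data comfortably fits in RAM; past that point, the lazy __getitem__() approach above is the safer default.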

Handling Data Transformations

Separating transformation logic from data loading is critical for flexibility. You want to use the same dataset class with different augmentation strategies for training versus validation.

from torchvision import transforms

# Training transforms with augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

# Validation transforms without augmentation
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

# Same dataset class, different behaviors
train_dataset = ImageFolderDataset('data/train', transform=train_transform)
val_dataset = ImageFolderDataset('data/val', transform=val_transform)

This pattern keeps your dataset class clean and reusable. Never hardcode transforms inside __getitem__()—you’ll regret it when you need different preprocessing for evaluation.
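One caveat when splitting: torch.utils.data.random_split returns views over a single dataset instance, so both splits inherit the same transform. To give each split its own pipeline, split indices yourself and wrap two separately configured instances in Subset. A sketch using a stand-in dataset and hypothetical lambda transforms:

```python
import torch
from torch.utils.data import Dataset, Subset

class ToyDataset(Dataset):
    """Stand-in for ImageFolderDataset: returns (value, label) pairs."""
    def __init__(self, transform=None):
        self.data = list(range(10))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        if self.transform:
            x = self.transform(x)
        return x, idx % 2

# Shuffle indices once, then share the split across two instances
indices = torch.randperm(10).tolist()
train_ds = Subset(ToyDataset(transform=lambda v: v * 2), indices[:8])
val_ds = Subset(ToyDataset(transform=None), indices[8:])
```

Both Subset views index into the same underlying files, but each wraps an instance configured with its own transform.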

Advanced Dataset Patterns

Real-world datasets often require sophisticated loading strategies. Here’s a dataset with lazy loading and optional caching for large-scale scenarios:

import torch
from torch.utils.data import Dataset
from PIL import Image
import pickle

class LazyLoadDataset(Dataset):
    def __init__(self, data_list, transform=None, cache_size=1000):
        """
        Args:
            data_list: List of (file_path, label) tuples
            transform: Optional transform to apply
            cache_size: Number of samples to keep in memory (0 = no cache)
        """
        self.data_list = data_list
        self.transform = transform
        self.cache_size = cache_size
        self.cache = {}
    
    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        # Check cache first
        if idx in self.cache:
            image, label = self.cache[idx]
        else:
            # Load from disk
            img_path, label = self.data_list[idx]
            image = Image.open(img_path).convert('RGB')
            
            # Cache if not full
            if len(self.cache) < self.cache_size:
                self.cache[idx] = (image.copy(), label)
        
        # Apply transforms (torchvision transforms return new objects,
        # so the cached PIL image is not mutated)
        if self.transform:
            image = self.transform(image)
        
        return image, label

This cache simply fills up to cache_size and never evicts—the first cache_size distinct samples stay resident, so it is not a true LRU cache. For datasets where loading is expensive (high-resolution images, compressed formats), caching frequently accessed samples can still significantly speed up training. Adjust cache_size based on available RAM, and note that with num_workers > 0 each DataLoader worker process holds its own copy of the cache.
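If you do want genuine least-recently-used eviction, one option (a sketch, with the expensive decode left as a stand-in) is to wrap the loader in functools.lru_cache:

```python
from functools import lru_cache

import torch
from torch.utils.data import Dataset

class LRUCachedDataset(Dataset):
    def __init__(self, data_list, transform=None, cache_size=1000):
        self.data_list = data_list
        self.transform = transform
        # Per-instance LRU cache around the expensive load step;
        # lru_cache handles eviction of least-recently-used entries
        self._load = lru_cache(maxsize=cache_size)(self._load_raw)

    def _load_raw(self, idx):
        # Stand-in for the real decode, e.g. Image.open(path).convert('RGB')
        path, label = self.data_list[idx]
        return path, label

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        sample, label = self._load(idx)
        if self.transform:
            sample = self.transform(sample)
        return sample, label
```

As with the manual cache, each DataLoader worker process keeps its own independent copy, so budget RAM as cache_size × num_workers.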

For multimodal data, return a richer structure from __getitem__():

# Inside a dataset class that defines self.image_transform and self.tokenizer
def __getitem__(self, idx):
    img_path, caption, label = self.data_list[idx]
    
    image = Image.open(img_path).convert('RGB')
    if self.image_transform:
        image = self.image_transform(image)
    
    # Tokenize text
    tokens = self.tokenizer(caption)
    
    return {
        'image': image,
        'text': tokens,
        'label': label
    }

Returning dictionaries instead of tuples improves readability when handling multiple modalities.
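Conveniently, DataLoader's default collate already understands dictionaries: a batch of dicts with matching keys comes back as one dict whose values are batched. A quick sketch with stand-in tensors:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            'image': torch.zeros(3, 4, 4),  # stand-in for a real image tensor
            'label': idx % 2,
        }

loader = DataLoader(DictDataset(), batch_size=4)
batch = next(iter(loader))
# batch['image'] is stacked along a new batch dimension;
# batch['label'] is collected into a single tensor
```

So as long as every sample's tensors share a shape per key, no custom collate function is needed for dict-returning datasets.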

Integrating with DataLoader

DataLoader handles batching, shuffling, and multiprocessing. For variable-length sequences, you need a custom collate function:

import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_variable_length(batch):
    """
    Custom collate for variable-length sequences.
    Expects batch to be list of (sequence_tensor, label) tuples.
    """
    sequences, labels = zip(*batch)
    
    # Pad sequences to max length in batch
    sequences_padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    labels = torch.tensor(labels)
    
    # Return lengths for packed sequences if needed
    lengths = torch.tensor([len(seq) for seq in sequences])
    
    return sequences_padded, labels, lengths

# Dataset that returns variable-length sequences
class SequenceDataset(Dataset):
    def __init__(self, sequences, labels):
        self.sequences = sequences  # List of variable-length tensors
        self.labels = labels
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]

# DataLoader setup
dataset = SequenceDataset(sequences, labels)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_variable_length,
    pin_memory=True  # Faster GPU transfer
)

# Usage in training loop
for sequences, labels, lengths in dataloader:
    # sequences is now padded to same length
    # lengths tells you original sequence lengths
    outputs = model(sequences, lengths)

The num_workers parameter enables multiprocessing for data loading. Start with 4 workers and adjust based on your CPU and I/O characteristics. Set pin_memory=True when using CUDA for faster host-to-device transfers.
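One multiprocessing pitfall worth guarding against: worker processes get distinct PyTorch seeds automatically, but NumPy's global RNG is not reseeded per worker, so NumPy-based augmentations can produce identical "random" values across workers. A sketch of the standard worker_init_fn fix, using a throwaway TensorDataset:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # Derive a distinct per-worker NumPy seed from PyTorch's worker seed
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)

dataset = TensorDataset(torch.randn(64, 8), torch.zeros(64))
dataloader = DataLoader(
    dataset,
    batch_size=16,
    num_workers=2,
    worker_init_fn=seed_worker,
)
```

If your augmentations use only torch's RNG, this is unnecessary; it matters once NumPy or Python's random module enters the pipeline.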

Testing and Debugging Tips

Before training, validate your dataset thoroughly. Here are essential debugging utilities:

import matplotlib.pyplot as plt
import numpy as np

def inspect_dataset(dataset, num_samples=5):
    """Visualize random samples from dataset."""
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    
    for idx in indices:
        sample = dataset[idx]
        
        if isinstance(sample, tuple):
            image, label = sample
        else:
            image = sample['image']
            label = sample['label']
        
        # Convert tensor to numpy for visualization
        if torch.is_tensor(image):
            image = image.permute(1, 2, 0).numpy()
            # Normalized images fall outside [0, 1]; denormalize for a
            # faithful view, or simply clip into displayable range
            image = np.clip(image, 0, 1)
        
        print(f"Sample {idx}: Label = {label}, Shape = {image.shape}")
        plt.figure()
        plt.imshow(image)
        plt.title(f"Label: {label}")
        plt.axis('off')
        plt.show()

def validate_dataset_shapes(dataset, expected_shape=None):
    """Check that all samples have consistent shapes."""
    shapes = set()
    
    for i in range(min(100, len(dataset))):  # Check first 100
        sample = dataset[i]
        if isinstance(sample, tuple):
            data = sample[0]
        else:
            data = sample['image']
        
        shapes.add(tuple(data.shape))
    
    print(f"Found {len(shapes)} unique shapes: {shapes}")
    
    if expected_shape and expected_shape not in shapes:
        print(f"WARNING: Expected shape {expected_shape} not found!")
    
    return shapes

# Usage
dataset = ImageFolderDataset('data/train', transform=train_transform)
inspect_dataset(dataset, num_samples=3)
validate_dataset_shapes(dataset, expected_shape=(3, 224, 224))

Common pitfalls to watch for: forgetting to convert PIL images to tensors, incorrect normalization values, file path issues on different operating systems, and memory leaks from unclosed file handles. Always test your dataset with a small subset before launching full training runs.
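A cheap final check before a long run is to pull a single batch end-to-end. A sketch using num_workers=0 so any exception surfaces with a clean single-process traceback, shown here with a stand-in TensorDataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def smoke_test(dataset, batch_size=4):
    """Pull one batch through a single-process DataLoader and print shapes."""
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
    batch = next(iter(loader))
    items = batch.values() if isinstance(batch, dict) else batch
    for item in items:
        if torch.is_tensor(item):
            print(item.dtype, tuple(item.shape))
    return batch

# Substitute your real dataset for the stand-in below
batch = smoke_test(TensorDataset(torch.randn(10, 3, 8, 8), torch.zeros(10)))
```

If this passes, re-run with your real num_workers setting; multiprocessing failures (unpicklable dataset state, worker crashes) only appear then.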

Custom datasets are the foundation of any PyTorch project. Invest time in getting them right—proper data loading, efficient caching, and clean separation of concerns will pay dividends throughout your model development cycle.
