How to Split Arrays in NumPy
Key Insights
- NumPy provides multiple splitting functions (split, array_split, hsplit, vsplit, dsplit), each optimized for different use cases—choosing the right one simplifies your code and prevents errors.
- Use array_split() instead of split() when your array length isn't evenly divisible; it handles remainders gracefully rather than raising an exception.
- Splitting at specific indices gives you precise control over chunk boundaries, which is essential for creating train/test splits and processing data in custom batch sizes.
Introduction
Array splitting is one of those operations you’ll reach for constantly once you know it exists. Whether you’re preparing data for machine learning, processing large datasets in manageable chunks, or distributing work across parallel processes, breaking arrays into smaller pieces is fundamental.
NumPy provides a family of splitting functions that handle everything from simple equal divisions to complex multi-dimensional slicing. The problem is that most developers only know about split() and end up writing verbose workarounds when they hit its limitations. This article covers the complete toolkit so you can pick the right function for each situation.
Basic Array Splitting with numpy.split()
The split() function divides an array into equal parts along a specified axis. It’s the simplest splitting operation and works well when you know your array divides evenly.
import numpy as np
# Create a 1D array with 12 elements
data = np.arange(12)
print(f"Original array: {data}")
# Split into 3 equal parts
chunks = np.split(data, 3)
for i, chunk in enumerate(chunks):
    print(f"Chunk {i}: {chunk}")
Output:
Original array: [ 0 1 2 3 4 5 6 7 8 9 10 11]
Chunk 0: [0 1 2 3]
Chunk 1: [4 5 6 7]
Chunk 2: [ 8 9 10 11]
The function returns a list of arrays, not a single array. This is important—you can iterate over the result directly or access chunks by index.
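One consequence worth knowing (documented behavior for split() and array_split()): the returned sub-arrays are views into the original array, not copies, so modifying a chunk modifies the source. A quick sketch:

```python
import numpy as np

data = np.arange(12)
chunks = np.split(data, 3)
print(type(chunks).__name__)  # list

# The chunks are views into the original array, not copies:
chunks[0][0] = 99
print(data[0])  # 99
```

If you need independent copies, call .copy() on each chunk.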
Here’s the catch: split() raises a ValueError if the array doesn’t divide evenly:
# This will fail
data = np.arange(10)
try:
    chunks = np.split(data, 3)
except ValueError as e:
    print(f"Error: {e}")
Output:
Error: array split does not result in an equal division
This strictness is sometimes what you want—it catches assumptions about your data that don’t hold. But when you need flexibility, reach for array_split().
Splitting into Unequal Parts with numpy.array_split()
The array_split() function handles arrays that don’t divide evenly by distributing remainder elements across the first chunks. It never raises an exception for uneven divisions.
# Split 10 elements into 3 chunks
data = np.arange(10)
print(f"Original array: {data}")
chunks = np.array_split(data, 3)
for i, chunk in enumerate(chunks):
    print(f"Chunk {i} (length {len(chunk)}): {chunk}")
Output:
Original array: [0 1 2 3 4 5 6 7 8 9]
Chunk 0 (length 4): [0 1 2 3]
Chunk 1 (length 3): [4 5 6]
Chunk 2 (length 3): [7 8 9]
Notice how the first chunk gets 4 elements while the others get 3. The algorithm distributes the remainder (10 mod 3 = 1) to the earlier chunks. With 11 elements split into 3 parts, you’d get chunks of 4, 4, and 3.
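You can verify the distribution rule directly by printing chunk lengths for a few array sizes; the first (n mod k) chunks each receive one extra element:

```python
import numpy as np

# array_split gives the first (n % k) chunks one extra element
for n in (10, 11, 12):
    sizes = [len(c) for c in np.array_split(np.arange(n), 3)]
    print(f"{n} elements -> {sizes}")
# 10 elements -> [4, 3, 3]
# 11 elements -> [4, 4, 3]
# 12 elements -> [4, 4, 4]
```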
This behavior makes array_split() ideal for batch processing where you want roughly equal batches without padding:
def process_in_batches(data, batch_size):
    """Process data in batches of approximately batch_size."""
    n_batches = int(np.ceil(len(data) / batch_size))
    batches = np.array_split(data, n_batches)
    results = []
    for batch in batches:
        # Your processing logic here
        results.append(batch.mean())
    return results
data = np.random.randn(1000)
batch_means = process_in_batches(data, batch_size=64)
print(f"Processed {len(batch_means)} batches")
Horizontal and Vertical Splitting (hsplit, vsplit)
For 2D arrays, NumPy provides convenience functions that make your intent clearer than specifying axis numbers.
hsplit() splits horizontally (along columns):
# Create a 4x6 matrix
matrix = np.arange(24).reshape(4, 6)
print("Original matrix:")
print(matrix)
# Split into 3 parts horizontally (by columns)
left, middle, right = np.hsplit(matrix, 3)
print("\nLeft section:")
print(left)
print("\nMiddle section:")
print(middle)
print("\nRight section:")
print(right)
Output:
Original matrix:
[[ 0 1 2 3 4 5]
[ 6 7 8 9 10 11]
[12 13 14 15 16 17]
[18 19 20 21 22 23]]
Left section:
[[ 0 1]
[ 6 7]
[12 13]
[18 19]]
Middle section:
[[ 2 3]
[ 8 9]
[14 15]
[20 21]]
Right section:
[[ 4 5]
[10 11]
[16 17]
[22 23]]
vsplit() splits vertically (along rows):
# Split into 2 parts vertically (by rows)
top, bottom = np.vsplit(matrix, 2)
print("Top section:")
print(top)
print("\nBottom section:")
print(bottom)
Output:
Top section:
[[ 0 1 2 3 4 5]
[ 6 7 8 9 10 11]]
Bottom section:
[[12 13 14 15 16 17]
[18 19 20 21 22 23]]
These functions are equivalent to calling split() with axis=1 and axis=0 respectively, but the names communicate intent better. Use them when working with tabular data, images, or any 2D structure where “horizontal” and “vertical” have meaning.
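The equivalence is easy to confirm with a quick comparison of the resulting chunks:

```python
import numpy as np

matrix = np.arange(24).reshape(4, 6)

# hsplit is split along axis=1; vsplit is split along axis=0
h_equal = all(np.array_equal(a, b)
              for a, b in zip(np.hsplit(matrix, 3), np.split(matrix, 3, axis=1)))
v_equal = all(np.array_equal(a, b)
              for a, b in zip(np.vsplit(matrix, 2), np.split(matrix, 2, axis=0)))
print(h_equal, v_equal)  # True True
```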
Splitting 3D Arrays with dsplit()
When working with 3D arrays—common in image processing, video data, and scientific computing—dsplit() splits along the third axis (depth). This is particularly useful for separating color channels in RGB images.
# Create a 4x4 "image" with 3 color channels (RGB)
image = np.arange(48).reshape(4, 4, 3)
print(f"Image shape: {image.shape}")
# Split into individual color channels
red, green, blue = np.dsplit(image, 3)
print(f"\nRed channel shape: {red.shape}")
print("Red channel values:")
print(red.squeeze()) # Remove the singleton dimension for display
Output:
Image shape: (4, 4, 3)
Red channel shape: (4, 4, 1)
Red channel values:
[[ 0 3 6 9]
[12 15 18 21]
[24 27 30 33]
[36 39 42 45]]
Note that dsplit() preserves the third dimension (shape becomes (4, 4, 1) rather than (4, 4)). Use squeeze() if you need to remove it.
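Because the third dimension is preserved, reassembling the channels is straightforward: np.dstack() concatenates along the third axis and acts as the inverse of dsplit(). A round-trip check:

```python
import numpy as np

image = np.arange(48).reshape(4, 4, 3)
red, green, blue = np.dsplit(image, 3)

# dstack concatenates along axis 2, undoing the dsplit
restored = np.dstack([red, green, blue])
print(np.array_equal(restored, image))  # True
```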
For more complex 3D splitting scenarios, use the general split() function with an explicit axis:
# Split a 3D array along the first axis
volume = np.arange(24).reshape(4, 3, 2)
slices = np.split(volume, 2, axis=0)
print(f"Each slice shape: {slices[0].shape}")
Splitting at Specific Indices
All NumPy splitting functions accept either a number of sections or an array of indices. Index-based splitting gives you precise control over where cuts happen.
data = np.arange(15)
print(f"Original: {data}")
# Split at indices 2, 5, and 8
# This creates 4 chunks: [0:2], [2:5], [5:8], [8:]
chunks = np.split(data, [2, 5, 8])
for i, chunk in enumerate(chunks):
    print(f"Chunk {i}: {chunk}")
Output:
Original: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14]
Chunk 0: [0 1]
Chunk 1: [2 3 4]
Chunk 2: [5 6 7]
Chunk 3: [ 8 9 10 11 12 13 14]
The indices specify where to cut, not the chunk sizes. You always get len(indices) + 1 chunks. This works with all the splitting functions:
matrix = np.arange(20).reshape(4, 5)
print("Original matrix:")
print(matrix)
# Split columns at indices 1 and 3
parts = np.hsplit(matrix, [1, 3])
print(f"\nNumber of parts: {len(parts)}")
for i, part in enumerate(parts):
    print(f"Part {i}:\n{part}")
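One edge case worth knowing: split indices beyond the array length don't raise an error. NumPy simply returns empty sub-arrays for those positions, which can silently hide off-by-one mistakes. A minimal sketch:

```python
import numpy as np

data = np.arange(5)
# The index 10 is past the end, so the final chunk is empty
chunks = np.split(data, [2, 10])
for i, chunk in enumerate(chunks):
    print(f"Chunk {i}: {chunk}")
# Chunk 0: [0 1]
# Chunk 1: [2 3 4]
# Chunk 2: []
```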
Practical Use Cases
Let’s look at real-world applications that combine these techniques.
Train/Validation/Test Split
def train_val_test_split(X, y, train_ratio=0.7, val_ratio=0.15):
    """Split data into train, validation, and test sets."""
    n_samples = len(X)
    # Calculate split indices
    train_end = int(n_samples * train_ratio)
    val_end = int(n_samples * (train_ratio + val_ratio))
    # Shuffle indices
    indices = np.random.permutation(n_samples)
    # Split indices
    train_idx, val_idx, test_idx = np.split(indices, [train_end, val_end])
    return (X[train_idx], y[train_idx],
            X[val_idx], y[val_idx],
            X[test_idx], y[test_idx])
# Example usage
X = np.random.randn(1000, 10)
y = np.random.randint(0, 2, 1000)
X_train, y_train, X_val, y_val, X_test, y_test = train_val_test_split(X, y)
print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")
Chunked File Processing
def process_large_array(filepath, chunk_size=10000):
    """Process a large array in memory-efficient chunks."""
    data = np.load(filepath)
    n_chunks = int(np.ceil(len(data) / chunk_size))
    chunks = np.array_split(data, n_chunks)
    results = []
    for i, chunk in enumerate(chunks):
        # expensive_computation is a placeholder for your own per-chunk function
        result = expensive_computation(chunk)
        results.append(result)
        print(f"Processed chunk {i+1}/{n_chunks}")
    return np.concatenate(results)
K-Fold Cross Validation Indices
def k_fold_indices(n_samples, k=5):
    """Generate train/test indices for k-fold cross validation."""
    indices = np.arange(n_samples)
    np.random.shuffle(indices)
    folds = np.array_split(indices, k)
    for i in range(k):
        test_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, test_idx
# Example usage
for fold, (train_idx, test_idx) in enumerate(k_fold_indices(100, k=5)):
print(f"Fold {fold}: {len(train_idx)} train, {len(test_idx)} test")
The key to using NumPy’s splitting functions effectively is matching the function to your requirements: use split() when equal divisions are mandatory, array_split() when you need flexibility, and the axis-specific variants when working with multi-dimensional data. Index-based splitting handles everything else.