NumPy - np.sum() with axis Parameter


Key Insights

• The axis parameter in np.sum() determines the dimension along which summation occurs, with axis=0 summing down columns, axis=1 summing across rows, and axis=None (default) summing all elements.
• Understanding axis operations requires visualizing arrays as nested structures where axis=0 operates on the outermost level, axis=1 on the next level inward, and so on for higher dimensions.
• Combining axis with keepdims=True preserves array dimensionality after reduction, critical for broadcasting operations in machine learning pipelines.

Understanding the Axis Parameter Fundamentals

The axis parameter in NumPy’s sum() function controls which dimension collapses during the summation operation. When you sum along an axis, that dimension disappears from the resulting array shape.

import numpy as np

# 2D array example
arr = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

print(f"Original shape: {arr.shape}")  # (3, 3)

# Sum along axis 0 (down the columns)
sum_axis_0 = np.sum(arr, axis=0)
print(f"axis=0 result: {sum_axis_0}")  # [12 15 18]
print(f"axis=0 shape: {sum_axis_0.shape}")  # (3,)

# Sum along axis 1 (across the rows)
sum_axis_1 = np.sum(arr, axis=1)
print(f"axis=1 result: {sum_axis_1}")  # [ 6 15 24]
print(f"axis=1 shape: {sum_axis_1.shape}")  # (3,)

# Sum all elements (default)
sum_all = np.sum(arr)
print(f"Total sum: {sum_all}")  # 45

The mental model: axis=0 moves vertically down the rows, so the row dimension collapses and one value per column remains. axis=1 moves horizontally across the columns, so the column dimension collapses and one value per row remains.
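One way to check this mental model is to write the reductions out by hand. The sketch below is illustrative only (the names manual_axis_0 and manual_axis_1 are not from the original example): it rebuilds both results without the axis argument and compares them with np.sum().

```python
import numpy as np

arr = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

# axis=0: add the rows together element-wise (the row index is summed over)
manual_axis_0 = arr[0] + arr[1] + arr[2]

# axis=1: add up the entries inside each row (the column index is summed over)
manual_axis_1 = np.array([row.sum() for row in arr])

print(np.array_equal(manual_axis_0, np.sum(arr, axis=0)))  # True
print(np.array_equal(manual_axis_1, np.sum(arr, axis=1)))  # True
```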

Working with 3D Arrays

Three-dimensional arrays introduce axis=2, which operates on the innermost dimension. Visualize 3D arrays as stacks of 2D matrices.

# Create a 3D array: 2 matrices of 3x4
arr_3d = np.array([[[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]],
                    
                    [[13, 14, 15, 16],
                     [17, 18, 19, 20],
                     [21, 22, 23, 24]]])

print(f"3D array shape: {arr_3d.shape}")  # (2, 3, 4)

# axis=0: Sum across the depth (combine the two matrices)
sum_axis_0 = np.sum(arr_3d, axis=0)
print(f"axis=0 shape: {sum_axis_0.shape}")  # (3, 4)
print(f"axis=0 result:\n{sum_axis_0}")
# [[14 16 18 20]
#  [22 24 26 28]
#  [30 32 34 36]]

# axis=1: Sum down rows within each matrix
sum_axis_1 = np.sum(arr_3d, axis=1)
print(f"axis=1 shape: {sum_axis_1.shape}")  # (2, 4)
print(f"axis=1 result:\n{sum_axis_1}")
# [[15 18 21 24]
#  [51 54 57 60]]

# axis=2: Sum across columns within each row
sum_axis_2 = np.sum(arr_3d, axis=2)
print(f"axis=2 shape: {sum_axis_2.shape}")  # (2, 3)
print(f"axis=2 result:\n{sum_axis_2}")
# [[10 26 42]
#  [58 74 90]]

Each axis operation reduces dimensionality by one. The axis you specify is the dimension that disappears.
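The shape rule can be stated as a small function. The helper reduced_shape below is hypothetical, written only for this illustration: it drops one entry from the shape tuple and compares the prediction with what np.sum() returns for a single integer axis.

```python
import numpy as np

def reduced_shape(shape, axis):
    """Hypothetical helper: shape left after summing over one non-negative axis."""
    return shape[:axis] + shape[axis + 1:]

arr_3d = np.zeros((2, 3, 4))

for axis in range(arr_3d.ndim):
    predicted = reduced_shape(arr_3d.shape, axis)
    actual = np.sum(arr_3d, axis=axis).shape
    print(f"axis={axis}: predicted {predicted}, actual {actual}")
```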

Multiple Axes and Tuple Parameters

You can sum along multiple axes simultaneously by passing a tuple to the axis parameter.

arr_3d = np.arange(24).reshape(2, 3, 4)

# Sum along axes 0 and 1 (leaves only axis 2)
sum_axes_01 = np.sum(arr_3d, axis=(0, 1))
print(f"axis=(0,1) shape: {sum_axes_01.shape}")  # (4,)
print(f"Result: {sum_axes_01}")  # [60 66 72 78]

# Sum along axes 1 and 2 (leaves only axis 0)
sum_axes_12 = np.sum(arr_3d, axis=(1, 2))
print(f"axis=(1,2) shape: {sum_axes_12.shape}")  # (2,)
print(f"Result: {sum_axes_12}")  # [ 66 210]

# Equivalent to summing all elements
sum_all_axes = np.sum(arr_3d, axis=(0, 1, 2))
print(f"All axes sum: {sum_all_axes}")  # 276

This technique proves invaluable when working with batched data in deep learning, where you need to aggregate across specific dimensions while preserving others.
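A tuple of axes gives the same result as reducing one axis at a time, which can help when reasoning about which dimensions survive. A minimal sketch, reusing the same arange-based array:

```python
import numpy as np

arr_3d = np.arange(24).reshape(2, 3, 4)

# Reduce axes 0 and 1 in a single call
one_call = np.sum(arr_3d, axis=(0, 1))

# Same reduction, one axis at a time. Once axis 0 is removed,
# the old axis 1 becomes the new axis 0.
chained = np.sum(np.sum(arr_3d, axis=0), axis=0)

print(np.array_equal(one_call, chained))  # True
```

The single call is usually easier to read because the axis numbers always refer to the original array.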

The keepdims Parameter

Setting keepdims=True maintains the original number of dimensions by keeping reduced axes as size-1 dimensions. This enables proper broadcasting in subsequent operations.

arr = np.array([[1, 2, 3],
                [4, 5, 6]])

# Without keepdims
sum_normal = np.sum(arr, axis=1)
print(f"Normal shape: {sum_normal.shape}")  # (2,)
print(f"Normal result: {sum_normal}")  # [ 6 15]

# With keepdims
sum_keepdims = np.sum(arr, axis=1, keepdims=True)
print(f"keepdims shape: {sum_keepdims.shape}")  # (2, 1)
print(f"keepdims result:\n{sum_keepdims}")
# [[ 6]
#  [15]]

# Broadcasting example: normalize by row sums
normalized = arr / sum_keepdims
print(f"Normalized array:\n{normalized}")
# [[0.16666667 0.33333333 0.5       ]
#  [0.26666667 0.33333333 0.4       ]]

Without keepdims, you’d need to reshape the result manually for broadcasting compatibility.
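For comparison, here is roughly what the manual version looks like. This is a sketch using np.newaxis to put the reduced axis back; the variable names are illustrative.

```python
import numpy as np

arr = np.array([[1, 2, 3],
                [4, 5, 6]])

row_sums = np.sum(arr, axis=1)  # shape (2,)

# arr / row_sums would raise a ValueError: shapes (2, 3) and (2,) do not broadcast.
# Re-inserting the reduced axis by hand gives shape (2, 1), which does.
manual = arr / row_sums[:, np.newaxis]

with_keepdims = arr / np.sum(arr, axis=1, keepdims=True)
print(np.allclose(manual, with_keepdims))  # True
```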

Practical Application: Batch Processing

Real-world scenario: computing statistics across batches of images represented as 4D arrays (batch_size, height, width, channels).

# Simulate batch of 10 RGB images (32x32 pixels)
images = np.random.randint(0, 256, size=(10, 32, 32, 3), dtype=np.uint8)

# Compute mean pixel value per image (across spatial dimensions and channels)
mean_per_image = np.sum(images, axis=(1, 2, 3)) / (32 * 32 * 3)
print(f"Mean per image shape: {mean_per_image.shape}")  # (10,)

# Compute sum per channel across entire batch
sum_per_channel = np.sum(images, axis=(0, 1, 2))
print(f"Sum per channel: {sum_per_channel}")  # Shape: (3,)

# Compute spatial sum for each image, keeping channel dimension
spatial_sum = np.sum(images, axis=(1, 2), keepdims=True)
print(f"Spatial sum shape: {spatial_sum.shape}")  # (10, 1, 1, 3)

# Now we can normalize each image by its spatial sum
normalized_images = images / spatial_sum
print(f"Normalized shape: {normalized_images.shape}")  # (10, 32, 32, 3)
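A quick sanity check for this kind of normalization is to sum the normalized result over the same spatial axes and confirm that every (image, channel) slice adds up to 1. The sketch below assumes a seeded generator and pixel values starting at 1, so no spatial sum can be zero; both choices are made for this illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded so the check is reproducible
images = rng.integers(1, 256, size=(10, 32, 32, 3), dtype=np.uint8)

spatial_sum = np.sum(images, axis=(1, 2), keepdims=True)  # shape (10, 1, 1, 3)
normalized_images = images / spatial_sum                   # shape (10, 32, 32, 3)

# Each (image, channel) slice should now sum to 1 over height and width
check = np.sum(normalized_images, axis=(1, 2))
print(check.shape)              # (10, 3)
print(np.allclose(check, 1.0))  # True
```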

Performance Considerations and Data Types

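The dtype parameter sets the type of the accumulator and of the result. By default, np.sum() promotes integer inputs narrower than the platform integer (int8, int16, and so on) to the platform integer before accumulating, so small integer arrays do not overflow on their own. Overflow appears when you force a narrow dtype or when the total exceeds the accumulator's range, and in both cases the value wraps around silently.

# Default: int8 input is accumulated in the platform integer
small_arr = np.array([100, 100, 100], dtype=np.int8)
result_default = np.sum(small_arr)
print(f"Default result: {result_default}")  # 300

# Forcing a narrow accumulator wraps around: 300 - 256 = 44
result_overflow = np.sum(small_arr, dtype=np.int8)
print(f"Overflow result: {result_overflow}")  # 44 (not 300!)

# Specify a wide dtype explicitly to rule out overflow
result_correct = np.sum(small_arr, dtype=np.int64)
print(f"Correct result: {result_correct}")  # 300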

# Floating point for precision
float_arr = np.array([1e10, 1.0, -1e10])
result_float32 = np.sum(float_arr, dtype=np.float32)
result_float64 = np.sum(float_arr, dtype=np.float64)
print(f"float32: {result_float32}")  # 0.0 (precision loss)
print(f"float64: {result_float64}")  # 1.0 (correct)

For large arrays, specify dtype explicitly to balance memory usage and numerical stability.
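One common way to strike that balance is to store data in float32 and widen only the accumulator. A minimal sketch, with an array size chosen arbitrarily for illustration:

```python
import numpy as np

# One million float32 values take 4 MB; the same data as float64 would take 8 MB
x = np.full(1_000_000, 0.1, dtype=np.float32)

# Only the accumulator and the result are float64; x itself stays float32
total = np.sum(x, dtype=np.float64)

print(x.dtype, x.nbytes)  # float32 4000000
print(total.dtype)        # float64
```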

Advanced Pattern: Conditional Summation

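Combine boolean masks with axis-based summation for conditional aggregations. Multiplying by a mask zeroes out the excluded elements, and summing the mask itself counts the elements that match.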

data = np.array([[10, -5, 3],
                 [-2, 8, -1],
                 [7, -3, 9]])

# Sum only positive values along each axis
positive_mask = data > 0
sum_positive_axis0 = np.sum(data * positive_mask, axis=0)
print(f"Positive sum axis=0: {sum_positive_axis0}")  # [17  8 12]

# Count positive values per row
count_positive_per_row = np.sum(positive_mask, axis=1)
print(f"Positive count per row: {count_positive_per_row}")  # [2 1 2]

# Weighted sum using axis parameter
weights = np.array([0.5, 0.3, 0.2])
weighted_sum = np.sum(data * weights, axis=1)
print(f"Weighted sum: {weighted_sum}")  # [4.1 1.2 4.4]
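np.sum() also accepts a where argument (available since NumPy 1.17) that expresses the same kind of conditional sum without multiplying by the mask. A short sketch on the same data:

```python
import numpy as np

data = np.array([[10, -5, 3],
                 [-2, 8, -1],
                 [7, -3, 9]])

# `where` selects which elements take part in the sum
sum_positive_axis0 = np.sum(data, axis=0, where=data > 0)
print(sum_positive_axis0)  # [17  8 12]
```

Unlike the multiplication approach, where skips the excluded elements entirely, so a NaN in an excluded position does not leak into the result.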

Negative Axis Indexing

Negative indices count from the last axis backward, useful for dimension-agnostic code.

arr = np.random.rand(5, 4, 3, 2)

# axis=-1 is equivalent to axis=3 (last axis)
sum_last = np.sum(arr, axis=-1)
print(f"axis=-1 shape: {sum_last.shape}")  # (5, 4, 3)

# axis=-2 is equivalent to axis=2 (second to last)
sum_second_last = np.sum(arr, axis=-2)
print(f"axis=-2 shape: {sum_second_last.shape}")  # (5, 4, 2)

# Works with tuples too
sum_last_two = np.sum(arr, axis=(-2, -1))
print(f"axis=(-2,-1) shape: {sum_last_two.shape}")  # (5, 4)

This approach makes functions more flexible when working with variable-dimension inputs.
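As an illustration, the hypothetical helper normalize_last_axis below combines axis=-1 with keepdims=True so that the same function works on a 1D vector and on a 4D batch. The function name and inputs are assumptions made for this sketch.

```python
import numpy as np

def normalize_last_axis(x):
    """Hypothetical helper: scale entries so they sum to 1 along the last axis."""
    return x / np.sum(x, axis=-1, keepdims=True)

vec = np.array([1.0, 3.0])     # 1D input
batch = np.ones((5, 4, 3, 2))  # 4D input

print(normalize_last_axis(vec))          # [0.25 0.75]
print(normalize_last_axis(batch).shape)  # (5, 4, 3, 2)
```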
