NumPy - np.where() - Conditional Element Selection
Key Insights
- np.where() provides vectorized conditional operations that are often 50-100x faster than Python loops for array element selection and replacement
- The function operates in two modes: ternary conditional (three arguments) for element-wise replacement, or index-based (one argument) for finding the positions of True values
- Understanding broadcasting rules and return types is critical: the single-argument form returns a tuple of index arrays, while the three-argument form returns an array with the broadcast shape of its inputs
Basic Syntax and Return Types
np.where() behaves differently based on the number of arguments provided. With one argument, it returns indices where the condition is True. With three arguments, it performs element-wise selection.
import numpy as np
# Single argument - returns indices
arr = np.array([1, 2, 3, 4, 5])
indices = np.where(arr > 3)
print(indices) # (array([3, 4]),)
print(arr[indices]) # [4 5]
# Three arguments - ternary conditional
result = np.where(arr > 3, arr * 10, arr)
print(result) # [ 1 2 3 40 50]
The single-argument form returns a tuple of arrays—one for each dimension. For 2D arrays, you get two arrays representing row and column indices:
matrix = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
rows, cols = np.where(matrix > 5)
print(rows) # [1 2 2 2]
print(cols) # [2 0 1 2]
print(matrix[rows, cols]) # [6 7 8 9]
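As a side note, the NumPy documentation describes the single-argument form as shorthand for np.nonzero(), and np.argwhere() returns the same positions grouped as one (row, col) pair per match. A minimal sketch, reusing the matrix above:

```python
import numpy as np

matrix = np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])

# np.where(condition) with one argument matches np.nonzero(condition)
rows, cols = np.where(matrix > 5)
nz_rows, nz_cols = np.nonzero(matrix > 5)
same = np.array_equal(rows, nz_rows) and np.array_equal(cols, nz_cols)
print(same)  # True

# np.argwhere groups the same positions as one (row, col) pair per match
pairs = np.argwhere(matrix > 5)
print(pairs.tolist())  # [[1, 2], [2, 0], [2, 1], [2, 2]]
```

The tuple form is convenient for indexing back into the array, while the paired form is usually easier to loop over or log.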
Ternary Conditional Operations
The three-argument form np.where(condition, x, y) replaces elements: where condition is True, use x; where False, use y. This is NumPy’s vectorized equivalent of a ternary operator.
temperatures = np.array([15, 22, 8, 30, 18, 5, 25])
# Classify temperatures
classification = np.where(temperatures > 20, 'Hot', 'Cold')
print(classification)
# ['Cold' 'Hot' 'Cold' 'Hot' 'Cold' 'Cold' 'Hot']
# Apply different transformations
celsius = np.array([0, 10, 20, 30, 40])
fahrenheit = np.where(celsius > 25,
celsius * 9/5 + 32, # Convert to F if > 25
celsius) # Keep as C otherwise
print(fahrenheit) # [ 0. 10. 20. 86. 104.]
Both replacement values can be arrays, enabling complex element-wise operations:
prices = np.array([100, 50, 200, 75, 150])
discount_rate = np.array([0.1, 0.2, 0.15, 0.25, 0.1])
premium_rate = np.array([0.05, 0.1, 0.05, 0.1, 0.05])
# Apply different discount rates based on price
final_prices = np.where(prices > 100,
prices * (1 - discount_rate),
prices * (1 - premium_rate))
print(final_prices) # [ 95. 45. 170. 67.5 135.]
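One detail worth keeping in mind with the three-argument form: both replacement expressions are ordinary Python arguments, so they are computed for every element before np.where() picks between them. The sketch below illustrates this with a division; the array names are illustrative, and the alternative shown (the `where` and `out` arguments of np.divide) is one possible workaround rather than the only one.

```python
import numpy as np

numerator = np.array([10.0, 20.0, 30.0])
denominator = np.array([2.0, 0.0, 5.0])

# np.where() receives numerator / denominator already computed for every
# element, so the division by zero still happens (errstate only silences
# the resulting warning here)
with np.errstate(divide="ignore", invalid="ignore"):
    via_where = np.where(denominator != 0, numerator / denominator, 0.0)

# One alternative: restrict the ufunc itself with its `where` argument and
# supply a pre-filled `out` array for the positions that are skipped
safe = np.divide(numerator, denominator,
                 out=np.zeros_like(numerator),
                 where=denominator != 0)

print(via_where)  # [5. 0. 6.]
print(safe)       # [5. 0. 6.]
```

Both versions give the same values; the second simply never evaluates the division where the denominator is zero.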
Nested Conditions and Multiple Criteria
Chain np.where() calls for multiple conditions, or use logical operators to combine conditions in a single call.
scores = np.array([45, 67, 89, 72, 55, 91, 38, 78])
# Nested where for multiple thresholds
grades = np.where(scores >= 90, 'A',
np.where(scores >= 80, 'B',
np.where(scores >= 70, 'C',
np.where(scores >= 60, 'D', 'F'))))
print(grades)
# ['F' 'D' 'B' 'C' 'F' 'A' 'F' 'C']
# Multiple conditions with logical operators
age = np.array([25, 17, 30, 15, 45, 22])
income = np.array([30000, 0, 50000, 0, 80000, 25000])
# Eligible if age >= 18 AND income > 20000
eligible = np.where((age >= 18) & (income > 20000), 'Yes', 'No')
print(eligible) # ['Yes' 'No' 'Yes' 'No' 'Yes' 'Yes']
For complex multi-condition logic, np.select() provides cleaner syntax:
values = np.array([5, 15, 25, 35, 45])
conditions = [
values < 10,
(values >= 10) & (values < 30),
values >= 30
]
choices = ['Low', 'Medium', 'High']
result = np.select(conditions, choices, default='Unknown')
print(result) # ['Low' 'Medium' 'Medium' 'High' 'High']
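To make the comparison concrete, here is a sketch that expresses the grading thresholds from the nested example with np.select(). It relies on np.select() taking the first condition that is True for each element, so the thresholds are listed from highest to lowest.

```python
import numpy as np

scores = np.array([45, 67, 89, 72, 55, 91, 38, 78])

# Nested form: each inner call handles the "else" branch of the outer one
nested = np.where(scores >= 90, 'A',
         np.where(scores >= 80, 'B',
         np.where(scores >= 70, 'C',
         np.where(scores >= 60, 'D', 'F'))))

# Flat form: the first matching condition wins, `default` covers the rest
conditions = [scores >= 90, scores >= 80, scores >= 70, scores >= 60]
choices = ['A', 'B', 'C', 'D']
selected = np.select(conditions, choices, default='F')

print(selected.tolist())  # ['F', 'D', 'B', 'C', 'F', 'A', 'F', 'C']
print(np.array_equal(nested, selected))  # True
```

The flat list of conditions tends to stay readable as more tiers are added, where the nested form grows one level of parentheses per tier.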
Performance Optimization Patterns
np.where() dramatically outperforms Python loops. Benchmark comparison:
import time
# Setup
large_array = np.random.randint(0, 100, 1000000)
# Python loop approach
start = time.perf_counter()
result_loop = []
for val in large_array:
    result_loop.append(val * 2 if val > 50 else val)
loop_time = time.perf_counter() - start
# NumPy where approach
start = time.perf_counter()
result_where = np.where(large_array > 50, large_array * 2, large_array)
where_time = time.perf_counter() - start
print(f"Loop: {loop_time:.4f}s") # ~0.15s
print(f"Where: {where_time:.4f}s") # ~0.002s
print(f"Speedup: {loop_time/where_time:.1f}x") # ~75x
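Single-shot timings like these vary with hardware and background load. A slightly more robust sketch uses timeit.repeat() and takes the best run; the function names and the smaller array size are illustrative choices, not part of the benchmark above.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
sample = rng.integers(0, 100, 100_000)

def with_loop(values):
    # Element-by-element Python evaluation
    return [val * 2 if val > 50 else val for val in values]

def with_where(values):
    # Vectorized selection
    return np.where(values > 50, values * 2, values)

# Both approaches should agree before their timings are compared
results_match = np.array_equal(np.array(with_loop(sample)), with_where(sample))

# The best of several repeats is less sensitive to background noise
loop_time = min(timeit.repeat(lambda: with_loop(sample), number=1, repeat=3))
where_time = min(timeit.repeat(lambda: with_where(sample), number=1, repeat=3))

print(results_match)
print(f"Loop: {loop_time:.4f}s, where: {where_time:.4f}s")
```

The exact ratio depends on the machine and array size, so treat any speedup figure as indicative.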
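When you only need the matching values, np.where() adds work rather than saving it: every approach below builds the boolean array, and np.where() then converts it to index arrays as an extra step.
data = np.random.randn(1000000)
# Slowest - np.where() builds the boolean array, then converts it to index arrays
positive = data[np.where(data > 0)]
negative = data[np.where(data <= 0)]
# Simpler - boolean indexing, but the comparison runs twice
positive = data[data > 0]
negative = data[data <= 0]
# Usually the most efficient when both halves are needed - compute the mask once and reuse it
mask = data > 0
positive = data[mask]
negative = data[~mask]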
Broadcasting and Shape Handling
np.where() follows NumPy broadcasting rules, enabling operations between arrays of different shapes:
# 1D condition with 2D arrays
matrix = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
row_condition = np.array([True, False, True])
# Broadcasts condition across columns
result = np.where(row_condition[:, np.newaxis], matrix, 0)
print(result)
# [[1 2 3]
# [0 0 0]
# [7 8 9]]
# Scalar replacement values
result = np.where(matrix > 5, 999, -1)
print(result)
# [[-1 -1 -1]
# [-1 -1 999]
# [999 999 999]]
Handle shape mismatches explicitly:
# 3D array operations
cube = np.random.randint(0, 10, (3, 4, 5))
threshold = np.array([5, 6, 7, 8]) # Shape (4,)
# Broadcast threshold across first and third dimensions
threshold_3d = threshold[np.newaxis, :, np.newaxis]
result = np.where(cube > threshold_3d, 1, 0)
print(result.shape) # (3, 4, 5)
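If it is unclear whether two shapes are compatible, np.broadcast_shapes() (available in NumPy 1.20 and later) can check them without allocating any data. A small sketch using the shapes from the example above:

```python
import numpy as np

cube_shape = (3, 4, 5)

# A (4,) threshold reshaped to (1, 4, 1) lines up with the middle axis
print(np.broadcast_shapes(cube_shape, (1, 4, 1)))  # (3, 4, 5)

# Left as (4,), it would be matched against the last axis (length 5) and fail
try:
    np.broadcast_shapes(cube_shape, (4,))
    compatible = True
except ValueError:
    compatible = False
print(compatible)  # False
```

This is why the explicit reshape with np.newaxis is needed: broadcasting aligns shapes from the trailing axis, not from the axis you happen to have in mind.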
Real-World Applications
Data Cleaning: Replace outliers and missing values.
sensor_data = np.array([23.5, 24.1, -999, 23.8, 150.0, 24.2, -999, 23.9])
# Replace error codes and outliers
cleaned = np.where(sensor_data == -999, np.nan, sensor_data)
cleaned = np.where((cleaned < 0) | (cleaned > 100), np.nan, cleaned)
print(cleaned)
# [23.5 24.1 nan 23.8 nan 24.2 nan 23.9]
# Fill with mean of valid values
mean_val = np.nanmean(cleaned)
cleaned = np.where(np.isnan(cleaned), mean_val, cleaned)
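The steps above could be wrapped in a small reusable function. The sketch below is one way to do that; the name clean_readings and its parameters (error_code, low, high) are hypothetical and chosen only to mirror the example.

```python
import numpy as np

def clean_readings(readings, error_code=-999, low=0.0, high=100.0):
    """Hypothetical helper: mask error codes and out-of-range values, then mean-fill."""
    readings = np.asarray(readings, dtype=float)
    invalid = (readings == error_code) | (readings < low) | (readings > high)
    masked = np.where(invalid, np.nan, readings)
    if np.isnan(masked).all():
        return masked  # nothing valid to average, so leave the NaNs in place
    return np.where(np.isnan(masked), np.nanmean(masked), masked)

sensor_data = [23.5, 24.1, -999, 23.8, 150.0, 24.2, -999, 23.9]
filled = clean_readings(sensor_data)
print(filled)
```

Converting to float up front matters because an integer array cannot hold NaN. The early return avoids taking the mean of an all-NaN array.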
Financial Calculations: Apply tiered rates and thresholds.
portfolio_values = np.array([5000, 15000, 50000, 100000, 250000])
# Tiered management fees
fees = np.where(portfolio_values < 10000, portfolio_values * 0.02,
np.where(portfolio_values < 50000, portfolio_values * 0.015,
np.where(portfolio_values < 100000, portfolio_values * 0.01,
portfolio_values * 0.005)))
print(fees) # [ 100. 225. 500. 500. 1250.]
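For tiers defined purely by numeric boundaries, a lookup-table approach is a possible alternative to nesting. The sketch below assumes the same boundaries and rates as the fee schedule above and uses np.digitize() to map each value to a tier index; tier_bounds and tier_rates are illustrative names.

```python
import numpy as np

portfolio_values = np.array([5000, 15000, 50000, 100000, 250000])

# Upper bounds of the first three tiers; the last tier is open-ended
tier_bounds = np.array([10000, 50000, 100000])
tier_rates = np.array([0.02, 0.015, 0.01, 0.005])

# np.digitize returns, per value, how many bounds it has reached or passed
tier_index = np.digitize(portfolio_values, tier_bounds)
fees = portfolio_values * tier_rates[tier_index]

print(tier_index.tolist())  # [0, 1, 2, 3, 3]
print(fees)
```

Adding a tier then means extending two arrays instead of adding another level of nesting.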
Image Processing: Threshold and mask operations.
# Simulate grayscale image
image = np.random.randint(0, 256, (100, 100), dtype=np.uint8)
# Binary threshold
binary = np.where(image > 127, 255, 0)
# Adaptive adjustment
enhanced = np.where(image < 50, image * 1.5,
np.where(image > 200, image * 0.8, image))
enhanced = np.clip(enhanced, 0, 255).astype(np.uint8)
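One thing to watch in the binary threshold: with plain Python integers as replacement values, the result uses NumPy's default integer type instead of the image's uint8. A hedged sketch of one way to keep the dtype, passing uint8 scalars (calling .astype(np.uint8) afterwards would also work):

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.integers(0, 256, (100, 100), dtype=np.uint8)

# Plain Python ints let NumPy fall back to its default integer type
binary_default = np.where(image > 127, 255, 0)

# uint8 scalars keep the result in the image's own dtype
binary_u8 = np.where(image > 127, np.uint8(255), np.uint8(0))

print(binary_default.dtype, binary_u8.dtype)
```

The values are identical either way; only the storage type, and therefore the memory footprint and compatibility with image libraries, differs.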
Common Pitfalls
Type Coercion: np.where() upcasts to accommodate both replacement values.
arr = np.array([1, 2, 3, 4, 5])
# Integer array becomes float
result = np.where(arr > 3, arr * 2.5, arr)
print(result.dtype) # float64
# String conversion
result = np.where(arr > 3, 'high', arr)
print(result) # ['1' '2' '3' 'high' 'high']
print(result.dtype) # <U21
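To anticipate the numeric upcast before it happens, np.result_type() reports the dtype two operands promote to. The sketch below covers only the integer-to-float case and shows one way to cast back when integers are required.

```python
import numpy as np

arr = np.array([1, 2, 3, 4, 5])
scaled = arr * 2.5

# np.result_type reports the dtype that the two operands promote to
predicted = np.result_type(arr, scaled)
result = np.where(arr > 3, scaled, arr)
print(predicted == result.dtype)  # True

# If integers are required, round and cast back explicitly.
# np.rint rounds halves to the nearest even value, so 12.5 becomes 12.
as_int = np.rint(result).astype(arr.dtype)
print(as_int.tolist())  # [1, 2, 3, 10, 12]
```

Making the cast explicit keeps the rounding decision visible instead of leaving it to an implicit conversion later in the pipeline.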
Empty Results: Single-argument form returns empty arrays when no matches exist.
arr = np.array([1, 2, 3])
indices = np.where(arr > 10)
print(indices) # (array([], dtype=int64),)
print(len(indices[0])) # 0
Always validate index arrays before using them for indexing operations to avoid silent failures in data pipelines.
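A minimal sketch of such a check follows. The helper name max_above is hypothetical. It guards against the empty case explicitly, because reductions on an empty selection behave inconsistently: .max() raises a ValueError, while .sum() quietly returns 0.

```python
import numpy as np

def max_above(values, threshold):
    """Hypothetical helper: largest value above threshold, or None if no match."""
    (idx,) = np.where(values > threshold)  # unpack the 1-tuple for a 1D array
    if idx.size == 0:
        return None  # explicit signal instead of reducing an empty selection
    return values[idx].max()

arr = np.array([1, 2, 3])
print(max_above(arr, 1))   # 3
print(max_above(arr, 10))  # None
```

Whether to return None, raise, or substitute a default depends on the pipeline. The point is to make the empty case a deliberate decision.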