Order-Statistic Tree: Rank and Select Operations
Key Insights
- Order-statistic trees augment each node with a subtree size, enabling O(log n) rank and select queries that would otherwise require O(n) time with naive approaches
- The key to maintaining correctness is updating size fields during every structural modification—insertions, deletions, and rotations
- Select navigates down from the root using left subtree sizes as pivot points, while rank accumulates left-subtree counts, either along the search path from the root or while walking up from a node with a parent pointer
Order-statistic trees solve a deceptively simple problem: given a dynamic collection of elements, how do you efficiently find the k-th smallest element or determine an element’s rank? With a sorted array, these operations are trivial—but arrays don’t handle insertions and deletions gracefully. Order-statistic trees give you the best of both worlds: O(log n) for all operations.
What Makes an Order-Statistic Tree
An order-statistic tree (OST) is a self-balancing binary search tree—typically a Red-Black tree—where each node stores an additional piece of information: the size of its subtree. This single augmentation unlocks two powerful operations:
- Select(k): Find the k-th smallest element in O(log n)
- Rank(x): Find how many elements are smaller than x in O(log n)
Without augmentation, both operations require O(n) time. You’d need to perform an in-order traversal counting elements until you hit your target. With subtree sizes cached at each node, you can make intelligent decisions at each step, eliminating half the remaining candidates.
Node Augmentation Strategy
The augmentation is straightforward. Each node stores its subtree size—the count of all nodes in its subtree, including itself. Here’s the structure:
```python
class OSTNode:
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.parent = None
        self.size = 1  # Includes this node
        # For Red-Black trees, add: self.color = RED

def get_size(node):
    """Safely get size, treating None as 0."""
    return node.size if node else 0

def update_size(node):
    """Recalculate size from children. Call after any structural change."""
    if node:
        node.size = 1 + get_size(node.left) + get_size(node.right)
```
The get_size helper handles null nodes cleanly—a pattern you’ll use constantly. The invariant is simple: node.size = 1 + left.size + right.size. Every operation that modifies the tree must preserve this invariant.
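As a sanity check, this small sketch (repeating the definitions above) hand-builds a three-node tree and verifies the invariant:

```python
class OSTNode:
    def __init__(self, key):
        self.key = key
        self.left = self.right = self.parent = None
        self.size = 1

def get_size(node):
    return node.size if node else 0

def update_size(node):
    if node:
        node.size = 1 + get_size(node.left) + get_size(node.right)

# Hand-build:   10
#              /  \
#             5    15
root = OSTNode(10)
root.left, root.right = OSTNode(5), OSTNode(15)

# Repair sizes bottom-up (the leaves are already size 1)
update_size(root)

assert root.size == 3
assert root.size == 1 + get_size(root.left) + get_size(root.right)  # the invariant
assert get_size(None) == 0  # null-safe by construction
```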
Select Operation: Finding the K-th Smallest Element
Select answers the question: “What element would be at index k if this tree were a sorted array?” The algorithm exploits a key insight: the left subtree size tells you exactly how many elements are smaller than the current node.
Here’s the logic:
- If k equals the left subtree size plus one, you’ve found your element
- If k is smaller, recurse into the left subtree
- If k is larger, recurse into the right subtree with an adjusted k
```python
def select(node, k):
    """
    Find the k-th smallest element (1-indexed).
    Returns the node containing that element, or None if k is out of bounds.
    """
    if node is None:
        return None
    # How many nodes are in the left subtree?
    left_size = get_size(node.left)
    # The current node's rank within this subtree
    current_rank = left_size + 1
    if k == current_rank:
        # Found it: exactly left_size elements are smaller
        return node
    elif k < current_rank:
        # Target is in the left subtree
        return select(node.left, k)
    else:
        # Target is in the right subtree
        # Subtract the nodes we're skipping (left subtree + current node)
        return select(node.right, k - current_rank)
```
Walk through an example. Suppose you have a tree with root 15 (size 7), left child 10 (size 3), and right child 20 (size 3). To find the 5th smallest:
- At root 15: left_size = 3, current_rank = 4. Since 5 > 4, go right with k = 5 - 4 = 1
- At node 20: left_size = 1, current_rank = 2. Since 1 < 2, go left
- At node 17: left_size = 0, current_rank = 1. Found it.
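The walkthrough can be verified directly. The sketch below rebuilds that tree; the sibling keys 5, 12, and 25 are assumed for illustration, since the walkthrough only names 15, 10, 20, and 17:

```python
class OSTNode:
    def __init__(self, key):
        self.key = key
        self.left = self.right = None
        self.size = 1

def get_size(node):
    return node.size if node else 0

def select(node, k):
    # Same recursive select as above
    if node is None:
        return None
    current_rank = get_size(node.left) + 1
    if k == current_rank:
        return node
    elif k < current_rank:
        return select(node.left, k)
    else:
        return select(node.right, k - current_rank)

def make(key, left=None, right=None):
    # Hypothetical builder: wires children and sets the size
    node = OSTNode(key)
    node.left, node.right = left, right
    node.size = 1 + get_size(left) + get_size(right)
    return node

# Root 15 (size 7) with subtrees rooted at 10 and 20 (size 3 each)
root = make(15,
            make(10, make(5), make(12)),
            make(20, make(17), make(25)))

assert select(root, 5).key == 17  # the walkthrough's answer
assert [select(root, k).key for k in range(1, 8)] == [5, 10, 12, 15, 17, 20, 25]
```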
The iterative version avoids recursion overhead:
```python
def select_iterative(node, k):
    """Iterative select—often faster in practice."""
    while node:
        left_size = get_size(node.left)
        current_rank = left_size + 1
        if k == current_rank:
            return node
        elif k < current_rank:
            node = node.left
        else:
            k -= current_rank
            node = node.right
    return None
```
Rank Operation: Finding an Element’s Position
Rank is the inverse of select: given a key, how many elements are strictly smaller? Note the indexing convention: select above is 1-indexed while rank is 0-indexed, so select(root, rank(root, x) + 1) lands back on x. The algorithm follows the standard BST search path for the key, accumulating at each step the count of elements that must appear before it in an in-order traversal.
```python
def rank(root, key):
    """
    Return the number of elements strictly less than key
    (the rank, i.e. the 0-indexed position in sorted order).
    """
    rank_count = 0
    node = root
    while node:
        if key < node.key:
            # Go left; the current node and its right subtree are all larger
            node = node.left
        elif key > node.key:
            # The current node and its left subtree are all smaller
            rank_count += get_size(node.left) + 1
            node = node.right
        else:
            # Found the key; add the left subtree size and we're done
            rank_count += get_size(node.left)
            return rank_count
    # Key not found; rank_count is where it would be inserted
    return rank_count
```
Notice that this works even if the key doesn’t exist—you get the count of elements that would precede it. This makes OSTs useful for problems like “how many elements are in the range [a, b]?” For integer keys, just compute rank(b + 1) - rank(a).
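For instance, a range count over integer keys can be built directly on rank (count_in_range is a hypothetical helper name; the definitions repeat those above):

```python
class OSTNode:
    def __init__(self, key):
        self.key = key
        self.left = self.right = None
        self.size = 1

def get_size(node):
    return node.size if node else 0

def rank(root, key):
    # Same rank as above: count of keys strictly less than key
    rank_count = 0
    node = root
    while node:
        if key < node.key:
            node = node.left
        elif key > node.key:
            rank_count += get_size(node.left) + 1
            node = node.right
        else:
            return rank_count + get_size(node.left)
    return rank_count

def count_in_range(root, a, b):
    # Hypothetical helper for integer keys: elements < b + 1, minus elements < a
    return rank(root, b + 1) - rank(root, a)

def make(key, left=None, right=None):
    # Hypothetical builder: wires children and sets the size
    node = OSTNode(key)
    node.left, node.right = left, right
    node.size = 1 + get_size(left) + get_size(right)
    return node

root = make(15, make(10, make(5), make(12)), make(20, make(17), make(25)))

assert count_in_range(root, 10, 20) == 5  # 10, 12, 15, 17, 20
assert rank(root, 13) == 3                # 13 is absent: 5, 10, 12 would precede it
```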
Here’s an alternative approach when you already have a reference to the node:
```python
def rank_from_node(node):
    """
    Compute rank when you have a direct reference to the node.
    Traverses up to the root, accumulating counts.
    """
    rank_count = get_size(node.left)
    while node.parent:
        if node == node.parent.right:
            # Coming from the right means the parent and its left subtree are smaller
            rank_count += get_size(node.parent.left) + 1
        node = node.parent
    return rank_count
```
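A quick check that the upward traversal agrees with the earlier walkthrough tree (make is a hypothetical builder that wires parent pointers and sizes):

```python
class OSTNode:
    def __init__(self, key):
        self.key = key
        self.left = self.right = self.parent = None
        self.size = 1

def get_size(node):
    return node.size if node else 0

def rank_from_node(node):
    # Same upward traversal as above
    rank_count = get_size(node.left)
    while node.parent:
        if node == node.parent.right:
            rank_count += get_size(node.parent.left) + 1
        node = node.parent
    return rank_count

def make(key, left=None, right=None):
    # Hypothetical builder: wires children, parents, and sizes
    node = OSTNode(key)
    node.left, node.right = left, right
    for child in (left, right):
        if child:
            child.parent = node
    node.size = 1 + get_size(left) + get_size(right)
    return node

n17 = make(17)
root = make(15, make(10, make(5), make(12)), make(20, n17, make(25)))

assert rank_from_node(n17) == 4   # 5, 10, 12, 15 precede 17
assert rank_from_node(root) == 3  # 5, 10, 12 precede 15
```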
Maintaining Augmentation During Tree Modifications
The tricky part isn’t implementing select and rank—it’s keeping the size fields correct as the tree changes. Every insertion, deletion, and rotation must update sizes appropriately.
For insertions, walk up from the newly inserted node and increment sizes:
```python
def insert(root, key):
    """Insert and maintain the size invariant (balancing fixup omitted)."""
    new_node = OSTNode(key)
    if not root:
        return new_node
    # Standard BST insertion
    parent = None
    current = root
    while current:
        parent = current
        current.size += 1  # Increment size as we descend
        if key < current.key:
            current = current.left
        else:
            current = current.right
    new_node.parent = parent
    if key < parent.key:
        parent.left = new_node
    else:
        parent.right = new_node
    # For a Red-Black tree, run the rebalancing fixup here; its rotations
    # must use the size-preserving versions shown below.
    return root
```
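Inserting a handful of keys and reading them back with select confirms the sizes stay consistent (definitions repeated from above):

```python
class OSTNode:
    def __init__(self, key):
        self.key = key
        self.left = self.right = self.parent = None
        self.size = 1

def get_size(node):
    return node.size if node else 0

def insert(root, key):
    # Same size-maintaining insert as above
    new_node = OSTNode(key)
    if not root:
        return new_node
    parent = None
    current = root
    while current:
        parent = current
        current.size += 1
        if key < current.key:
            current = current.left
        else:
            current = current.right
    new_node.parent = parent
    if key < parent.key:
        parent.left = new_node
    else:
        parent.right = new_node
    return root

def select(node, k):
    # Iterative select, as above
    while node:
        current_rank = get_size(node.left) + 1
        if k == current_rank:
            return node
        elif k < current_rank:
            node = node.left
        else:
            k -= current_rank
            node = node.right
    return None

root = None
for key in [15, 10, 20, 12, 5, 25, 17]:
    root = insert(root, key)

assert root.size == 7  # sizes were bumped along every insertion path
assert [select(root, k).key for k in range(1, 8)] == [5, 10, 12, 15, 17, 20, 25]
```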
Rotations require special care. After rotating, the sizes of the two affected nodes change:
```python
def left_rotate(root, x):
    """Left rotation preserving the size invariant."""
    y = x.right
    x.right = y.left
    if y.left:
        y.left.parent = x
    y.parent = x.parent
    if not x.parent:
        root = y
    elif x == x.parent.left:
        x.parent.left = y
    else:
        x.parent.right = y
    y.left = x
    x.parent = y
    # Critical: set y.size from x's old size first, then recompute x,
    # whose children have changed
    y.size = x.size  # y now roots the subtree that x used to root
    update_size(x)
    return root

def right_rotate(root, y):
    """Right rotation preserving the size invariant."""
    x = y.left
    y.left = x.right
    if x.right:
        x.right.parent = y
    x.parent = y.parent
    if not y.parent:
        root = x
    elif y == y.parent.right:
        y.parent.right = x
    else:
        y.parent.left = x
    x.right = y
    y.parent = x
    # Same ordering: x inherits y's old size, then y recalculates
    x.size = y.size
    update_size(y)
    return root
```
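A minimal check of the size bookkeeping: rotate the root of a right-leaning three-node chain and confirm both rotated nodes end up with correct sizes (make is a hypothetical builder; the rotation is the one above):

```python
class OSTNode:
    def __init__(self, key):
        self.key = key
        self.left = self.right = self.parent = None
        self.size = 1

def get_size(node):
    return node.size if node else 0

def update_size(node):
    if node:
        node.size = 1 + get_size(node.left) + get_size(node.right)

def left_rotate(root, x):
    # Same size-preserving rotation as above
    y = x.right
    x.right = y.left
    if y.left:
        y.left.parent = x
    y.parent = x.parent
    if not x.parent:
        root = y
    elif x == x.parent.left:
        x.parent.left = y
    else:
        x.parent.right = y
    y.left = x
    x.parent = y
    y.size = x.size
    update_size(x)
    return root

def make(key, left=None, right=None):
    # Hypothetical builder: wires children, parents, and sizes
    node = OSTNode(key)
    node.left, node.right = left, right
    for child in (left, right):
        if child:
            child.parent = node
    node.size = 1 + get_size(left) + get_size(right)
    return node

# Chain 10 -> 20 -> 30; rotating the root left puts 20 on top
root = make(10, None, make(20, None, make(30)))
root = left_rotate(root, root)

assert root.key == 20 and root.size == 3       # y took x's old size
assert root.left.key == 10 and root.left.size == 1  # x was recalculated
assert root.right.key == 30 and root.right.size == 1
```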
Practical Applications and Extensions
The classic application is the running median problem. Maintain two OSTs: one for the lower half, one for the upper half. After each insertion, rebalance by moving the boundary element between trees whenever their sizes differ by more than one. Select makes extracting each half's boundary element trivial.
```python
class RunningMedian:
    def __init__(self):
        self.lower = None  # Holds the lower half; its max is the boundary
        self.upper = None  # Holds the upper half; its min is the boundary
        self.lower_count = 0
        self.upper_count = 0

    def _get_lower_max(self):
        return select(self.lower, self.lower_count).key

    def _rebalance(self):
        # Move the boundary element between trees whenever the counts
        # differ by more than 1. This needs a size-maintaining delete,
        # which is omitted here.
        ...

    def add(self, value):
        # Insert into the appropriate tree
        if self.lower_count == 0 or value <= self._get_lower_max():
            self.lower = insert(self.lower, value)
            self.lower_count += 1
        else:
            self.upper = insert(self.upper, value)
            self.upper_count += 1
        # Rebalance if sizes differ by more than 1
        self._rebalance()

    def get_median(self):
        if self.lower_count > self.upper_count:
            return select(self.lower, self.lower_count).key
        elif self.upper_count > self.lower_count:
            return select(self.upper, 1).key
        else:
            lower_max = select(self.lower, self.lower_count).key
            upper_min = select(self.upper, 1).key
            return (lower_max + upper_min) / 2
```
Range counting becomes a one-liner: for integer keys, rank(b + 1) - rank(a) gives you the elements in [a, b] (for arbitrary ordered keys, use a rank variant that counts elements less than or equal to b instead). Percentile queries reduce to a single select with k = ceil(n * percentile / 100).
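The percentile idea can be sketched like this (percentile is a hypothetical helper; the ceiling convention for k is one common choice, assuming p in (0, 100]):

```python
import math

class OSTNode:
    def __init__(self, key):
        self.key = key
        self.left = self.right = self.parent = None
        self.size = 1

def get_size(node):
    return node.size if node else 0

def insert(root, key):
    # Size-maintaining BST insert, as above
    new_node = OSTNode(key)
    if not root:
        return new_node
    parent = None
    current = root
    while current:
        parent = current
        current.size += 1
        current = current.left if key < current.key else current.right
    new_node.parent = parent
    if key < parent.key:
        parent.left = new_node
    else:
        parent.right = new_node
    return root

def select(node, k):
    # Iterative select, as above
    while node:
        current_rank = get_size(node.left) + 1
        if k == current_rank:
            return node
        elif k < current_rank:
            node = node.left
        else:
            k -= current_rank
            node = node.right
    return None

def percentile(root, p):
    # Hypothetical helper: k-th smallest with k = ceil(n * p / 100)
    k = max(1, math.ceil(get_size(root) * p / 100))
    return select(root, k).key

root = None
for key in [3, 1, 4, 15, 9, 26, 5]:
    root = insert(root, key)

# Sorted order: 1, 3, 4, 5, 9, 15, 26 (n = 7)
assert percentile(root, 50) == 5    # k = ceil(3.5) = 4 -> 4th smallest
assert percentile(root, 100) == 26
```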
Compared to alternatives: Fenwick trees handle cumulative frequency queries but require coordinate compression for arbitrary keys. Segment trees are more flexible but have higher constant factors. OSTs shine when you need dynamic insertions and deletions with order statistics—they’re the natural choice for problems mixing both.
When to Use Order-Statistic Trees
Choose OSTs when you need:
- Dynamic data with frequent insertions and deletions
- Fast k-th element queries or rank computations
- Range counting without preprocessing
Skip them if your data is static (use a sorted array) or if you only need min/max (use a heap). The implementation complexity is moderate—you’re essentially adding one integer per node and a few lines to each rotation. The payoff is substantial: problems that seem to require O(n) per query become O(log n).
For further exploration, look into interval trees (augmented with max endpoint) and persistent data structures (where you can query historical versions). The augmentation pattern—storing derived information at each node—is one of the most powerful techniques in data structure design.