Wavelet Tree: Rank and Select Queries


Key Insights

  • Wavelet trees transform rank and select queries on arbitrary alphabets into O(log σ) bitvector operations, making them foundational for compressed text indexes like the FM-index.
  • The structure recursively partitions the alphabet and stores one bit per character at each level. That totals about n log σ bits, which is near-optimal for an uncompressed representation, and queries stay fast.
  • Rank queries traverse top-down using bitvector rank to narrow position ranges, while select queries work bottom-up, inverting the process through bitvector select operations.

Introduction to Wavelet Trees

Wavelet trees solve a deceptively simple problem: given a string over an alphabet of σ symbols, answer rank and select queries efficiently. These operations form the backbone of modern compressed text indexes, enabling substring search in space proportional to the compressed text size.

The naive approach scans the string linearly, yielding O(n) query time. For a bitvector (σ = 2), we can achieve O(1) time with auxiliary structures. Wavelet trees extend this efficiency to arbitrary alphabets, reducing queries to O(log σ) bitvector operations.
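For σ = 2, the simplest auxiliary structure is a full prefix-count array. The minimal sketch below (the class name `PrefixCountBitvector` is hypothetical) turns rank into a single lookup. Because it stores one integer per position, roughly n log n bits, it is not succinct. It only shows where the O(1) comes from.

```python
from itertools import accumulate
from typing import List


class PrefixCountBitvector:
    """O(1) rank over a bitvector using a full prefix-count array.

    Stores one integer per position (about n log n bits), so this is NOT
    succinct; it only illustrates constant-time rank via precomputation.
    """

    def __init__(self, bits: List[int]):
        # prefix[i] = number of 1s in bits[0:i]; prefix[0] = 0
        self.prefix = [0] + list(accumulate(bits))

    def rank1(self, i: int) -> int:
        """Number of 1s in bits[0:i]."""
        return self.prefix[i]

    def rank0(self, i: int) -> int:
        """Number of 0s in bits[0:i]."""
        return i - self.prefix[i]


bv = PrefixCountBitvector([0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0])
print(bv.rank1(7), bv.rank0(7))  # 2 5
```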

This matters because text indexing structures like the FM-index require millions of rank queries during pattern matching. Shaving complexity from O(n) to O(log σ) transforms impractical algorithms into production-ready systems.

Understanding Rank and Select Operations

Two operations define the interface:

Rank(c, i): Returns the count of character c in positions 0 through i-1. For string “abracadabra”, Rank(‘a’, 7) = 3 because positions 0, 3, and 5 contain ‘a’.

Select(c, k): Returns the (0-indexed) position of the k-th occurrence of character c, where occurrences are counted from k = 1. Select(‘a’, 3) = 5 for the same string.

Here’s the baseline implementation we’re trying to beat:

def naive_rank(text: str, c: str, i: int) -> int:
    """O(n) rank: count occurrences of c in text[0:i]"""
    return sum(1 for j in range(i) if text[j] == c)

def naive_select(text: str, c: str, k: int) -> int:
    """O(n) select: find position of k-th occurrence of c"""
    count = 0
    for i, char in enumerate(text):
        if char == c:
            count += 1
            if count == k:
                return i
    return -1  # Not found

These O(n) implementations become bottlenecks when queries number in the millions. FM-index backward search performs two rank queries per pattern character, one for each end of the current suffix-array range. Searching a 1GB text for a 100-character pattern therefore takes 200 rank queries. With the naive approach each of them can scan up to 10⁹ characters, so a single pattern costs on the order of 10¹¹ character comparisons.
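To see where those rank queries come from, here is a minimal count-only backward search. It assumes a `$` sentinel that does not occur in the text and sorts before every other symbol. It builds the BWT by sorting rotations, which is quadratic and only suitable for demonstration, and it answers rank with a naive scan in the spirit of `naive_rank` above. The names `bwt` and `backward_search_count` are hypothetical. Each pattern character costs two rank calls, one per end of the half-open range [sp, ep).

```python
def bwt(text: str) -> str:
    """Burrows-Wheeler transform via sorted rotations (demo only).

    Assumes '$' does not occur in text; '$' sorts before all letters.
    """
    s = text + "$"
    rotations = sorted(s[i:] + s[:i] for i in range(len(s)))
    return "".join(rot[-1] for rot in rotations)


def naive_rank(text: str, c: str, i: int) -> int:
    """O(n) rank: occurrences of c in text[0:i]."""
    return text[:i].count(c)


def backward_search_count(text: str, pattern: str) -> int:
    """Count occurrences of pattern in text with FM-index backward search."""
    L = bwt(text)
    # C[c] = number of symbols in L strictly smaller than c
    C = {}
    total = 0
    for c in sorted(set(L)):
        C[c] = total
        total += L.count(c)

    sp, ep = 0, len(L)  # current suffix-array range [sp, ep)
    for c in reversed(pattern):
        if c not in C:
            return 0
        # Two rank queries per pattern character: one for each range end
        sp = C[c] + naive_rank(L, c, sp)
        ep = C[c] + naive_rank(L, c, ep)
        if sp >= ep:
            return 0
    return ep - sp


print(backward_search_count("abracadabra", "abra"))  # 2
print(backward_search_count("abracadabra", "a"))     # 5
```

Swapping `naive_rank` for a wavelet-tree rank over `L` changes only the cost of each call, not the search logic.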

Wavelet Tree Structure

A wavelet tree recursively partitions the alphabet. At the root, we split characters into two halves: those in the “left” half of the alphabet and those in the “right” half. We store a bitvector where 0 indicates left-half characters and 1 indicates right-half characters.

Each child node repeats this process on its subset of the alphabet until leaves represent individual characters.

Consider the string “abracadabra” over alphabet {a, b, c, d, r}. We number the characters by lexicographic position:

  • a = 0, b = 1, c = 2, d = 3, r = 4

A node covering indices lo through hi splits them at mid = ⌊(lo + hi) / 2⌋, which is the same rule the code below uses. At the root (indices 0 through 4), mid = 2, so the root bitvector distinguishes {a, b, c} (bit 0) from {d, r} (bit 1):

String:    a b r a c a d a b r a
Root BV:   0 0 1 0 0 0 1 0 0 1 0

Characters with 0 bits descend to the left child; those with 1 bits go right. The left child receives “abacaaba” and further partitions {a, b} from {c}, while the right child receives “rdr” and partitions {d} from {r}.

from dataclasses import dataclass
from typing import Optional, List

@dataclass
class WaveletNode:
    bitvector: List[int]
    left: Optional['WaveletNode'] = None
    right: Optional['WaveletNode'] = None
    # Alphabet range this node covers
    alpha_lo: int = 0
    alpha_hi: int = 0

class WaveletTree:
    def __init__(self, text: str):
        self.alphabet = sorted(set(text))
        self.char_to_idx = {c: i for i, c in enumerate(self.alphabet)}
        self.root = self._build(text, 0, len(self.alphabet) - 1)
    
    def _build(self, text: str, lo: int, hi: int) -> Optional[WaveletNode]:
        if lo > hi or not text:
            return None
        if lo == hi:
            # Leaf node - no bitvector needed
            return WaveletNode(bitvector=[], alpha_lo=lo, alpha_hi=hi)
        
        mid = (lo + hi) // 2
        bitvector = []
        left_text = []
        right_text = []
        
        for c in text:
            idx = self.char_to_idx[c]
            if idx <= mid:
                bitvector.append(0)
                left_text.append(c)
            else:
                bitvector.append(1)
                right_text.append(c)
        
        node = WaveletNode(
            bitvector=bitvector,
            alpha_lo=lo,
            alpha_hi=hi
        )
        node.left = self._build(''.join(left_text), lo, mid)
        node.right = self._build(''.join(right_text), mid + 1, hi)
        return node
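To see exactly what `_build` stores for the running example, the helper below re-applies its midpoint rule and lists every internal node. The name `partitions` is hypothetical. It is a standalone re-implementation of the split logic for inspection, not part of the class.

```python
from typing import List, Tuple


def partitions(text: str) -> List[Tuple[str, str, List[int]]]:
    """Return (alphabet range, node text, bitvector) for every internal node.

    Same split as _build: symbols are numbered in sorted order, and a node
    covering indices lo..hi sends index <= (lo + hi) // 2 left (bit 0).
    Nodes are listed in preorder (node, left subtree, right subtree).
    """
    alphabet = sorted(set(text))
    idx = {c: i for i, c in enumerate(alphabet)}
    out: List[Tuple[str, str, List[int]]] = []

    def visit(sub: str, lo: int, hi: int) -> None:
        if lo >= hi:  # leaf (single symbol) or empty alphabet: no bitvector
            return
        mid = (lo + hi) // 2
        bits = [0 if idx[c] <= mid else 1 for c in sub]
        out.append(("".join(alphabet[lo:hi + 1]), sub, bits))
        visit("".join(c for c in sub if idx[c] <= mid), lo, mid)
        visit("".join(c for c in sub if idx[c] > mid), mid + 1, hi)

    visit(text, 0, len(alphabet) - 1)
    return out


for symbols, sub, bits in partitions("abracadabra"):
    print(f"{{{symbols}}}: {sub} -> {''.join(map(str, bits))}")
```

This prints `{abcdr}: abracadabra -> 00100010010`, `{abc}: abacaaba -> 00010000`, `{ab}: abaaaba -> 0100010`, and `{dr}: rdr -> 101`.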

The tree depth equals ⌈log₂ σ⌉. Each character contributes one bit per level on its root-to-leaf path, so the bitvectors hold at most n⌈log₂ σ⌉ bits in total. That matches the n log σ bits of a plain fixed-width encoding, which is the minimum needed in the worst case for an arbitrary string over σ symbols.
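The bound is easy to check on the running example. This sketch (the helper name `code_lengths` is hypothetical) computes each symbol's leaf depth under the midpoint split used by `_build`. Because σ = 5 is not a power of two, some symbols reach their leaf after fewer than ⌈log₂ σ⌉ levels, so the actual total falls below n⌈log₂ σ⌉.

```python
import math
from typing import Dict


def code_lengths(alphabet_size: int) -> Dict[int, int]:
    """Leaf depth of each symbol index under the split mid = (lo + hi) // 2."""
    depths: Dict[int, int] = {}

    def visit(lo: int, hi: int, depth: int) -> None:
        if lo == hi:
            depths[lo] = depth
            return
        mid = (lo + hi) // 2
        visit(lo, mid, depth + 1)
        visit(mid + 1, hi, depth + 1)

    if alphabet_size > 0:
        visit(0, alphabet_size - 1, 0)
    return depths


text = "abracadabra"
alphabet = sorted(set(text))
depths = code_lengths(len(alphabet))
# One bit per level on each character's root-to-leaf path
total_bits = sum(depths[alphabet.index(c)] for c in text)
bound = len(text) * math.ceil(math.log2(len(alphabet)))
print(max(depths.values()), total_bits, bound)  # 3 29 33
```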

Implementing Rank Queries

Rank queries traverse from root to leaf, using bitvector rank to track positions. At each node, we determine whether the target character belongs to the left or right subtree, then compute how many characters went that direction before our position.

def bitvector_rank(bv: List[int], bit: int, i: int) -> int:
    """Count occurrences of 'bit' in bv[0:i]"""
    return sum(1 for j in range(i) if bv[j] == bit)

def rank(self, c: str, i: int) -> int:
    """Count occurrences of c in text[0:i]"""
    if c not in self.char_to_idx:
        return 0
    
    target_idx = self.char_to_idx[c]
    node = self.root
    pos = i
    
    while node and node.left is not None:  # Not a leaf
        mid = (node.alpha_lo + node.alpha_hi) // 2
        
        if target_idx <= mid:
            # Character is in left subtree
            # Count 0s before position pos
            pos = bitvector_rank(node.bitvector, 0, pos)
            node = node.left
        else:
            # Character is in right subtree
            # Count 1s before position pos
            pos = bitvector_rank(node.bitvector, 1, pos)
            node = node.right
    
    return pos

WaveletTree.rank = rank  # attach as a method of the WaveletTree class above

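Let’s trace Rank(‘a’, 7) on “abracadabra”. The sorted alphabet a, b, c, d, r gets indices 0 through 4, and each node splits its range at mid = (lo + hi) // 2:

  1. Root (covers {a, b, c, d, r}): target ‘a’ (idx=0), mid=2. Since 0 ≤ 2, go left. The root bitvector is [0,0,1,0,0,0,1,0,0,1,0] (d and r map to 1). Count 0s in its first 7 positions = 5. New pos = 5.

  2. Left child (covers {a, b, c}): mid=1. Since 0 ≤ 1, go left. Its bitvector represents “abacaaba”: [0,0,0,1,0,0,0,0]. Count 0s in the first 5 positions = 4. New pos = 4.

  3. Left-left child (covers {a, b}): mid=0. Since 0 ≤ 0, go left. Its bitvector represents “abaaaba”: [0,1,0,0,0,1,0]. Count 0s in the first 4 positions = 3. New pos = 3.

  4. Leaf ‘a’: return 3, matching the three ‘a’s at positions 0, 3, and 5.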

Each level requires one bitvector rank operation. With O(1) bitvector rank (covered later), total complexity is O(log σ).
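A trace like this can also be produced mechanically. The hypothetical helper `rank_trace` below runs the same top-down walk with the midpoint split, recomputing each node's bitvector from its subsequence on the fly. That makes it O(n log σ) per query, so treat it as a teaching aid for checking traces, not as a data structure.

```python
from typing import List, Tuple


def rank_trace(text: str, c: str, i: int) -> Tuple[int, List[Tuple[str, int, int]]]:
    """Top-down rank(c, i) with the midpoint split, recording each level.

    Returns (rank, [(node alphabet, mid, new pos), ...]).
    """
    alphabet = sorted(set(text))
    if c not in alphabet:
        return 0, []
    idx = {ch: j for j, ch in enumerate(alphabet)}
    target = idx[c]
    lo, hi, sub, pos = 0, len(alphabet) - 1, text, i
    steps: List[Tuple[str, int, int]] = []
    while lo < hi:  # internal node
        mid = (lo + hi) // 2
        bit = 0 if target <= mid else 1
        bits = [0 if idx[ch] <= mid else 1 for ch in sub]
        pos = bits[:pos].count(bit)  # bitvector rank of `bit` before pos
        steps.append(("".join(alphabet[lo:hi + 1]), mid, pos))
        # Descend: keep only the symbols routed the same way as c
        sub = "".join(ch for ch, b in zip(sub, bits) if b == bit)
        lo, hi = (lo, mid) if bit == 0 else (mid + 1, hi)
    return pos, steps


result, steps = rank_trace("abracadabra", "a", 7)
print(result, steps)
# 3 [('abcdr', 2, 5), ('abc', 1, 4), ('ab', 0, 3)]
```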

Implementing Select Queries

Select inverts the rank process, traversing bottom-up from the leaf. We first locate the leaf for character c, then backtrack through ancestors, using bitvector select to map positions back up.

def bitvector_select(bv: List[int], bit: int, k: int) -> int:
    """Find position of k-th occurrence of 'bit' (1-indexed)"""
    count = 0
    for i, b in enumerate(bv):
        if b == bit:
            count += 1
            if count == k:
                return i
    return -1

def select(self, c: str, k: int) -> int:
    """Find position of k-th occurrence of c (1-indexed)"""
    if c not in self.char_to_idx or k <= 0:
        return -1
    
    target_idx = self.char_to_idx[c]
    
    # Find path from root to leaf
    path = []
    node = self.root
    while node and node.left is not None:
        mid = (node.alpha_lo + node.alpha_hi) // 2
        if target_idx <= mid:
            path.append((node, 0))  # Went left (bit 0)
            node = node.left
        else:
            path.append((node, 1))  # Went right (bit 1)
            node = node.right
    
    # Backtrack from leaf to root
    pos = k
    for node, bit in reversed(path):
        pos = bitvector_select(node.bitvector, bit, pos)
        if pos == -1:
            return -1
        pos += 1  # Convert to 1-indexed for next iteration
    
    return pos - 1  # Convert back to 0-indexed

WaveletTree.select = select  # attach as a method of the WaveletTree class above

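For Select(‘a’, 3) on “abracadabra”, we first walk down to the leaf for ‘a’ (going left at every level), then climb back up. At each level, the position of the k-th 0 (or 1), plus one, becomes k for the parent. The 3rd 0 in the {a, b} node is at index 3, so k = 4. The 4th 0 in the {a, b, c} node is at index 4, so k = 5. The 5th 0 in the root is at index 5, which is the answer.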
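The same climb can be checked with a small standalone sketch. The function name `select_trace` is hypothetical. It mirrors the two passes of `select` above (record the path top-down, then map positions bottom-up), but it rebuilds bitvectors from each node's subsequence instead of using a stored tree.

```python
from typing import List


def select_trace(text: str, c: str, k: int) -> List[int]:
    """Bottom-up select(c, k) with the midpoint split.

    Returns the 0-indexed position found at each level while climbing
    (the last entry is the answer), or [] if c has fewer than k occurrences.
    """
    alphabet = sorted(set(text))
    if c not in alphabet or k <= 0:
        return []
    idx = {ch: j for j, ch in enumerate(alphabet)}
    target = idx[c]
    # Top-down pass: remember (bitvector, direction) at each internal node
    path = []
    lo, hi, sub = 0, len(alphabet) - 1, text
    while lo < hi:
        mid = (lo + hi) // 2
        bit = 0 if target <= mid else 1
        bits = [0 if idx[ch] <= mid else 1 for ch in sub]
        path.append((bits, bit))
        sub = "".join(ch for ch, b in zip(sub, bits) if b == bit)
        lo, hi = (lo, mid) if bit == 0 else (mid + 1, hi)
    if k > len(sub):  # the leaf holds every occurrence of c
        return []
    if not path:  # single-symbol alphabet: the text is all c
        return [k - 1]
    # Bottom-up pass: the k-th `bit` in a node maps to a position in it
    positions = []
    for bits, bit in reversed(path):
        hits = [p for p, b in enumerate(bits) if b == bit]
        pos = hits[k - 1]  # 0-indexed position of the k-th `bit`
        positions.append(pos)
        k = pos + 1        # becomes the 1-indexed k one level up
    return positions


print(select_trace("abracadabra", "a", 3))  # [3, 4, 5]
```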

Optimizing with Succinct Bitvectors

The naive bitvector operations above are O(n). Production implementations use succinct bitvectors achieving O(1) rank and select with sublinear auxiliary space.

For rank, we precompute cumulative popcount at block boundaries:

from typing import List

class SuccinctBitvector:
    BLOCK_SIZE = 64
    SUPERBLOCK_SIZE = 256  # must be a multiple of BLOCK_SIZE
    
    def __init__(self, bits: List[int]):
        self.bits = bits
        self.n = len(bits)
        
        # Cumulative count of 1s before each superblock
        self.superblocks = []
        # Count of 1s before each block, relative to its superblock
        self.blocks = []
        
        cumulative = 0
        # Visit every boundary up to and including n, so rank1(n)
        # never indexes past the end of either table
        for i in range(self.n + 1):
            if i % self.SUPERBLOCK_SIZE == 0:
                self.superblocks.append(cumulative)
                self.blocks.append(0)
            elif i % self.BLOCK_SIZE == 0:
                self.blocks.append(cumulative - self.superblocks[-1])
            if i < self.n:
                cumulative += bits[i]
    
    def rank1(self, i: int) -> int:
        """O(1) rank for bit 1: number of 1s in bits[0:i]"""
        if i <= 0:
            return 0
        i = min(i, self.n)
        
        superblock_idx = i // self.SUPERBLOCK_SIZE
        block_idx = i // self.BLOCK_SIZE
        
        count = self.superblocks[superblock_idx]
        count += self.blocks[block_idx]
        
        # Scan remaining bits in current block (fewer than BLOCK_SIZE)
        start = block_idx * self.BLOCK_SIZE
        for j in range(start, i):
            count += self.bits[j]
        
        return count
    
    def rank0(self, i: int) -> int:
        """Number of 0s in bits[0:i]"""
        i = max(0, min(i, self.n))
        return i - self.rank1(i)

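The auxiliary tables add sublinear overhead when the block sizes grow with n. With superblocks of Θ(log² n) bits and blocks of Θ(log n) bits, the counts take O(n log log n / log n) = o(n) bits. The fixed 64/256 sizes used here instead cost a constant fraction of n, set by the counter widths. In practice, the final block scan uses CPU popcount instructions on 64-bit words, making it effectively O(1).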
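The popcount idea can be sketched by packing bits into 64-bit words with one cumulative count per word. In this variant, rank1 is a stored count plus the popcount of a masked word. The class name `PackedBitvector` is hypothetical, and Python has no direct CPU-popcount call, so `bin(w).count("1")` stands in for it (Python 3.10+ also offers `int.bit_count()`). The `select1` shown here binary-searches the per-word counts and then scans one word. That makes it O(log n), not the O(1) of structures that sample select positions.

```python
from typing import List


class PackedBitvector:
    """Bits packed into 64-bit words with one cumulative count per word."""

    WORD = 64

    def __init__(self, bits: List[int]):
        self.n = len(bits)
        self.words: List[int] = []
        for start in range(0, self.n, self.WORD):
            w = 0
            for offset, b in enumerate(bits[start:start + self.WORD]):
                w |= (b & 1) << offset  # bit j -> bit (j % 64) of its word
            self.words.append(w)
        # counts[w] = number of 1s in words[0:w]
        self.counts = [0]
        for w in self.words:
            self.counts.append(self.counts[-1] + bin(w).count("1"))

    def rank1(self, i: int) -> int:
        """Number of 1s in bits[0:i]."""
        i = max(0, min(i, self.n))
        word_idx, offset = divmod(i, self.WORD)
        count = self.counts[word_idx]
        if offset:
            mask = (1 << offset) - 1  # keep the low `offset` bits
            count += bin(self.words[word_idx] & mask).count("1")
        return count

    def select1(self, k: int) -> int:
        """0-indexed position of the k-th 1 (k is 1-indexed), or -1."""
        if k <= 0 or k > self.counts[-1]:
            return -1
        # Binary search for the first word whose cumulative count reaches k
        lo, hi = 0, len(self.words) - 1
        while lo < hi:
            mid = (lo + hi) // 2
            if self.counts[mid + 1] >= k:
                hi = mid
            else:
                lo = mid + 1
        remaining = k - self.counts[lo]
        w = self.words[lo]
        for offset in range(self.WORD):
            if (w >> offset) & 1:
                remaining -= 1
                if remaining == 0:
                    return lo * self.WORD + offset
        return -1  # unreachable when counts are consistent


bits = [1 if (i * i + 3 * i) % 5 < 2 else 0 for i in range(1000)]
pb = PackedBitvector(bits)
print(pb.rank1(10), pb.select1(3))  # 4 5
```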

Applications and Complexity Analysis

Wavelet trees enable several powerful applications:

FM-Index: The backward search algorithm requires O(m) rank queries for pattern length m. With wavelet trees, each query costs O(log σ), yielding O(m log σ) pattern matching in compressed space.

Range Quantile Queries: Find the k-th smallest element in any subarray in O(log σ) time. The wavelet tree naturally supports this by tracking how elements partition at each level.
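Here is a minimal sketch of the quantile query. It assumes integer values whose alphabet is the range [min, max]; sparse values would first be remapped to their ranks so that σ counts distinct values. The class and method names are hypothetical. Each node stores a prefix count of 0-bits, which is a non-succinct stand-in for the bitvector rank discussed earlier. A query walks one root-to-leaf path. If at least k elements of the current range went left, the answer is on the left. Otherwise, subtract that count and go right.

```python
from typing import List, Optional


class QuantileWaveletTree:
    """Wavelet tree over integers in [lo, hi] answering k-th smallest in a range."""

    def __init__(self, values: List[int], lo: Optional[int] = None, hi: Optional[int] = None):
        if lo is None or hi is None:
            lo, hi = (min(values), max(values)) if values else (0, 0)
        self.lo, self.hi = lo, hi
        self.left = self.right = None
        if lo == hi or not values:
            return  # leaf (or empty subtree, never visited by a valid query)
        mid = (lo + hi) // 2
        # zeros[i] = how many of values[0:i] go left (value <= mid)
        self.zeros = [0]
        for v in values:
            self.zeros.append(self.zeros[-1] + int(v <= mid))
        self.left = QuantileWaveletTree([v for v in values if v <= mid], lo, mid)
        self.right = QuantileWaveletTree([v for v in values if v > mid], mid + 1, hi)

    def kth_smallest(self, l: int, r: int, k: int) -> int:
        """k-th smallest (1-indexed) value in values[l:r]; needs 1 <= k <= r - l."""
        node = self
        while node.left is not None:
            zl, zr = node.zeros[l], node.zeros[r]
            left_count = zr - zl  # elements of [l, r) routed left
            if k <= left_count:
                l, r, node = zl, zr, node.left
            else:
                k -= left_count
                # Positions in the right child = number of 1s before l and r
                l, r, node = l - zl, r - zr, node.right
        return node.lo


values = [5, 1, 4, 2, 3]
tree = QuantileWaveletTree(values)
print(tree.kth_smallest(1, 4, 2))  # 2 (sorted [1, 4, 2] is [1, 2, 4])
```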

Document Retrieval: List all documents containing a pattern, ranked by frequency, using wavelet trees over document arrays.
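Document listing rests on one primitive: report every distinct value in a range of the document array, which holds the document id of each suffix in a suffix-array range, together with its count. The walk only descends into children whose mapped range is non-empty. Here is a minimal sketch under those assumptions, with hypothetical names `RangeWaveletTree` and `report` and integer document ids in a known range [lo, hi]. It lists every distinct document and then sorts by frequency. Returning only the top few without listing all of them needs more machinery than shown here.

```python
from typing import List, Tuple


class RangeWaveletTree:
    """Wavelet tree over document ids in [lo, hi] that lists distinct ids in a range."""

    def __init__(self, ids: List[int], lo: int, hi: int):
        self.lo, self.hi = lo, hi
        self.left = self.right = None
        if lo == hi or not ids:
            return
        mid = (lo + hi) // 2
        self.zeros = [0]  # zeros[i] = how many of ids[0:i] go left (id <= mid)
        for d in ids:
            self.zeros.append(self.zeros[-1] + int(d <= mid))
        self.left = RangeWaveletTree([d for d in ids if d <= mid], lo, mid)
        self.right = RangeWaveletTree([d for d in ids if d > mid], mid + 1, hi)

    def report(self, l: int, r: int) -> List[Tuple[int, int]]:
        """(doc id, count) for every distinct id in ids[l:r], in id order."""
        if l >= r:
            return []  # empty range: prune this subtree
        if self.left is None:
            return [(self.lo, r - l)]  # leaf: every element here equals lo
        zl, zr = self.zeros[l], self.zeros[r]
        return self.left.report(zl, zr) + self.right.report(l - zl, r - zr)


doc_array = [2, 0, 1, 2, 2, 0, 3, 1]
tree = RangeWaveletTree(doc_array, 0, 3)
hits = tree.report(2, 7)  # [(0, 1), (1, 1), (2, 2), (3, 1)]
ranked = sorted(hits, key=lambda h: -h[1])  # most frequent first, stable on ties
print(ranked)  # [(2, 2), (0, 1), (1, 1), (3, 1)]
```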

Complexity Summary:

  • Construction: O(n log σ) time
  • Rank query: O(log σ) time
  • Select query: O(log σ) time
  • Space: n log σ + o(n log σ) bits

The space bound is near-optimal: in the worst case, storing an arbitrary string of n symbols over an alphabet of size σ requires at least n log σ bits. The o(n log σ) term covers auxiliary structures for fast bitvector operations.

Wavelet trees exemplify the power of succinct data structures: we pay almost nothing beyond the raw data size yet gain logarithmic query times. For text indexing at scale, this tradeoff transforms theoretical algorithms into practical systems.
