Union-Find with Path Compression and Union by Rank

Key Insights

  • Union-Find with path compression and union by rank achieves near-constant O(α(n)) amortized time per operation, where α is the inverse Ackermann function—effectively constant for any practical input size.
  • Path compression flattens the tree structure during find operations, making subsequent queries faster; union by rank prevents tree imbalance during merges.
  • The combined optimizations transform a potentially O(n) worst-case operation into one of the most efficient data structures in computer science, essential for graph algorithms and connectivity problems.

Introduction to Union-Find

Union-Find, also known as Disjoint Set Union (DSU), is a data structure that tracks a collection of non-overlapping sets. It supports two primary operations: finding which set an element belongs to, and merging two sets together.

The elegance of Union-Find lies in its simplicity and efficiency. You’ll encounter it constantly in graph algorithms—Kruskal’s minimum spanning tree algorithm relies on it to detect cycles, network connectivity problems use it to track component membership, and image processing applications leverage it for connected component labeling.

The structure represents each set as a tree, where each element points to a parent. The root of each tree serves as the representative (or “leader”) of that set. When we need to check if two elements belong to the same set, we simply compare their roots.

Naive Implementation and Its Limitations

Let’s start with the straightforward approach. Each element maintains a parent pointer, and we traverse up the tree to find the root.

class NaiveUnionFind:
    def __init__(self, n: int):
        # Each element is its own parent initially
        self.parent = list(range(n))
    
    def find(self, x: int) -> int:
        # Traverse up to find root
        while self.parent[x] != x:
            x = self.parent[x]
        return x
    
    def union(self, x: int, y: int) -> None:
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            # Arbitrarily attach one tree to another
            self.parent[root_x] = root_y

This works, but there’s a critical flaw. Consider what happens when we union elements in a degenerate order:

uf = NaiveUnionFind(5)
uf.union(0, 1)  # 0 -> 1
uf.union(1, 2)  # 1 -> 2, so 0 -> 1 -> 2
uf.union(2, 3)  # 2 -> 3, so 0 -> 1 -> 2 -> 3
uf.union(3, 4)  # Creates a chain: 0 -> 1 -> 2 -> 3 -> 4

We’ve created a linked list. Now find(0) requires traversing four edges. With n elements, both find and union degrade to O(n) time. For algorithms processing millions of edges, this is unacceptable.

Path Compression Optimization

Path compression addresses the find inefficiency by flattening the tree structure during traversal. The insight is simple: once we’ve found the root, we might as well update every node we visited to point directly to it.

Here’s the recursive implementation:

def find_recursive(self, x: int) -> int:
    if self.parent[x] != x:
        self.parent[x] = self.find_recursive(self.parent[x])
    return self.parent[x]

This elegantly compresses the path in a single pass. After calling find(0) on our degenerate chain, every node along the path now points directly to the root.
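To see the effect concretely, here is a minimal sketch (the class name `PathCompressedUF` is illustrative) that rebuilds the degenerate chain from earlier and calls find once:

```python
# A minimal sketch: path compression flattens the degenerate chain
# 0 -> 1 -> 2 -> 3 -> 4 in a single recursive find.
class PathCompressedUF:
    def __init__(self, n: int):
        self.parent = list(range(n))

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

uf = PathCompressedUF(5)
uf.parent = [1, 2, 3, 4, 4]  # the chain built by the earlier unions
uf.find(0)
print(uf.parent)  # [4, 4, 4, 4, 4] -- every visited node now points at the root
```

A single find(0) turned a depth-4 chain into a depth-1 star; every later find on these elements is a single hop.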

The iterative version requires two passes but avoids recursion depth issues:

def find_iterative(self, x: int) -> int:
    # First pass: find root
    root = x
    while self.parent[root] != root:
        root = self.parent[root]
    
    # Second pass: compress path
    while self.parent[x] != root:
        next_parent = self.parent[x]
        self.parent[x] = root
        x = next_parent
    
    return root

There’s also a middle-ground technique called path halving that compresses every other node:

def find_path_halving(self, x: int) -> int:
    while self.parent[x] != x:
        self.parent[x] = self.parent[self.parent[x]]  # Skip one level
        x = self.parent[x]
    return x

Path halving achieves the same asymptotic complexity while doing all of its work in a single pass. In practice the two variants perform similarly: full path compression flattens more aggressively, while path halving touches each node only once per traversal.
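Running path halving on the same degenerate chain shows the "every other node" behavior. A standalone sketch (using a module-level parent list for brevity):

```python
# Path halving on the chain 0 -> 1 -> 2 -> 3 -> 4: each visited node is
# re-pointed two levels up, so roughly every other node gets compressed.
parent = [1, 2, 3, 4, 4]

def find_path_halving(x: int) -> int:
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # skip one level
        x = parent[x]
    return x

root = find_path_halving(0)
print(root)    # 4
print(parent)  # [2, 2, 4, 4, 4] -- partially flattened, not fully
```

Note the contrast with full compression, which would leave every node pointing directly at 4.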

Union by Rank Optimization

Path compression optimizes find, but we can also be smarter about how we merge trees. Union by rank tracks the “rank” of each tree (an upper bound on its height) and always attaches the shorter tree under the taller one.

class UnionFindByRank:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n  # Height upper bound
    
    def find(self, x: int) -> int:
        while self.parent[x] != x:
            x = self.parent[x]
        return x
    
    def union(self, x: int, y: int) -> bool:
        root_x = self.find(x)
        root_y = self.find(y)
        
        if root_x == root_y:
            return False  # Already in same set
        
        # Attach smaller rank tree under larger rank tree
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            # Same rank: pick one, increment its rank
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
        
        return True

The key insight: when ranks differ, the merged tree’s height doesn’t increase. Only when ranks are equal does the height grow by one. This guarantees the tree height is at most O(log n), giving us O(log n) find operations even without path compression.
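A quick check of this guarantee (a sketch that condenses the UnionFindByRank logic above so it runs standalone): merging 2^k singletons pairwise, then pairs into fours, and so on, is exactly the schedule that forces rank to grow at every level, and the final root's rank is still only k = log2(n).

```python
# Worst-case merge schedule for union by rank: rank only increases when
# two equal-rank roots meet, so even this schedule caps rank at log2(n).
class UnionFindByRank:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> None:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1

n = 16  # 2**4 elements
uf = UnionFindByRank(n)
step = 1
while step < n:
    # Merge blocks whose roots have equal rank -- the only way rank grows
    for i in range(0, n, 2 * step):
        uf.union(i, i + step)
    step *= 2

print(uf.rank[uf.find(0)])  # 4 == log2(16)
```

Even under this adversarial schedule the tree height (bounded by rank) stays logarithmic, which is what makes find O(log n) without any compression.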

An alternative is union by size, which tracks the number of elements rather than height. The logic is similar—attach the smaller set under the larger one:

def union_by_size(self, x: int, y: int) -> bool:
    # Note: __init__ must set self.size = [1] * n in place of the rank array
    root_x = self.find(x)
    root_y = self.find(y)
    
    if root_x == root_y:
        return False
    
    # Attach smaller tree under larger tree
    if self.size[root_x] < self.size[root_y]:
        self.parent[root_x] = root_y
        self.size[root_y] += self.size[root_x]
    else:
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
    
    return True

Both approaches achieve the same complexity bounds. Union by size has the advantage of tracking actual component sizes, which is often useful information.
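Since the sizes are maintained anyway, exposing them is a one-line method. A minimal self-contained sketch (the class name `UnionFindBySize` and the `component_size` helper are illustrative, not from a library):

```python
# Union by size, with the component-size query that rank cannot provide.
class UnionFindBySize:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n  # every element starts as a singleton

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx            # attach smaller tree under larger
        self.size[rx] += self.size[ry]  # only root sizes stay accurate
        return True

    def component_size(self, x: int) -> int:
        return self.size[self.find(x)]

uf = UnionFindBySize(5)
uf.union(0, 1)
uf.union(1, 2)
print(uf.component_size(2))  # 3 -- the component {0, 1, 2}
print(uf.component_size(4))  # 1 -- element 4 is still a singleton
```

One subtlety: only the entries at root indices remain meaningful after merges, which is why `component_size` goes through find first.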

Combined Implementation

Here’s a production-ready implementation combining both optimizations:

class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.count = n  # Number of disjoint sets
    
    def find(self, x: int) -> int:
        """Find with path compression."""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def union(self, x: int, y: int) -> bool:
        """Union by rank. Returns True if merge occurred."""
        root_x = self.find(x)
        root_y = self.find(y)
        
        if root_x == root_y:
            return False
        
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
        
        self.count -= 1
        return True
    
    def connected(self, x: int, y: int) -> bool:
        """Check if two elements are in the same set."""
        return self.find(x) == self.find(y)
    
    def get_count(self) -> int:
        """Return number of disjoint sets."""
        return self.count

With both optimizations, operations run in O(α(n)) amortized time, where α is the inverse Ackermann function. This function grows so slowly that α(n) ≤ 4 for any n less than 10^600. For all practical purposes, this is constant time.

Practical Applications

Let’s solve a classic problem: counting connected components in an undirected graph.

def count_components(n: int, edges: list[list[int]]) -> int:
    """
    Given n nodes labeled 0 to n-1 and a list of edges,
    return the number of connected components.
    """
    uf = UnionFind(n)
    
    for u, v in edges:
        uf.union(u, v)
    
    return uf.get_count()

# Example usage
edges = [[0, 1], [1, 2], [3, 4]]
print(count_components(5, edges))  # Output: 2
# Components: {0, 1, 2} and {3, 4}

Here’s another common application—detecting cycles in an undirected graph:

def has_cycle(n: int, edges: list[list[int]]) -> bool:
    """
    Returns True if the undirected graph contains a cycle.
    """
    uf = UnionFind(n)
    
    for u, v in edges:
        if uf.connected(u, v):
            # Edge connects nodes already in same component = cycle
            return True
        uf.union(u, v)
    
    return False

# Example
print(has_cycle(3, [[0, 1], [1, 2], [2, 0]]))  # True (triangle)
print(has_cycle(3, [[0, 1], [1, 2]]))          # False (path)

Complexity Analysis and Comparison

Here’s how the optimizations stack up:

| Implementation | Find (worst case) | Amortized per operation | Space |
|---|---|---|---|
| Naive | O(n) | O(n) | O(n) |
| Path compression only | O(n) | O(log n) | O(n) |
| Union by rank only | O(log n) | O(log n) | O(n) |
| Both combined | O(log n) | O(α(n)) | O(n) |

The combined approach’s amortized O(α(n)) bound is remarkable. For a sequence of m operations on n elements, the total time is O(m · α(n)), which is effectively O(m) for any realistic input.

Space complexity remains O(n) for all variants—we need the parent array regardless, and the rank array adds only linear overhead.

When implementing Union-Find, always use both optimizations. The additional code complexity is minimal, and the performance difference is substantial for large inputs. Path compression alone gets you most of the way there, but union by rank ensures consistent performance even for adversarial input patterns.
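To close the loop on the Kruskal's algorithm application mentioned in the introduction, here is a sketch of it built on the combined structure. The UnionFind class is a condensed copy of the full implementation above, repeated only so the snippet runs standalone; the `(weight, u, v)` edge format is an assumption for this example. Sorting dominates at O(E log E); the union-find work is effectively linear.

```python
# Kruskal's MST: sort edges by weight, add each edge unless it would
# close a cycle -- which union() detects by returning False.
class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.rank[rx] < self.rank[ry]:  # union by rank
            rx, ry = ry, rx
        self.parent[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True

def kruskal(n: int, edges: list[tuple[int, int, int]]) -> int:
    """Return the total weight of an MST. Edges are (weight, u, v)."""
    uf = UnionFind(n)
    total = 0
    for weight, u, v in sorted(edges):   # cheapest edges first
        if uf.union(u, v):               # False means u, v already connected
            total += weight
    return total

edges = [(1, 0, 1), (2, 1, 2), (2, 0, 3), (3, 0, 2), (4, 2, 3)]
print(kruskal(4, edges))  # 5 -- edges of weight 1, 2, 2 span all four nodes
```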
