Disjoint Set Union: Union-Find Implementation
Key Insights
- Union-Find with path compression and union by rank achieves near-constant O(α(n)) amortized time per operation, making it one of the most efficient data structures for dynamic connectivity problems.
- Path compression alone provides massive performance gains by flattening tree structures during find operations—always implement it, as the overhead is negligible.
- The data structure excels at problems involving equivalence relations, connected components, and cycle detection, but cannot efficiently handle edge deletions without significant modifications.
Introduction to Disjoint Sets
The Disjoint Set Union (DSU) data structure, commonly called Union-Find, solves a deceptively simple problem: tracking which elements belong to the same group when groups can merge but never split. You start with n elements, each in its own set, and need two operations to be efficient: a query, “Are these two elements in the same set?”, and a command, “Merge these two sets.”
This comes up constantly in practice. Network connectivity problems ask whether two nodes can communicate. Kruskal’s minimum spanning tree algorithm needs to know if adding an edge would create a cycle. Image segmentation groups adjacent pixels with similar colors. Social network analysis identifies friend clusters. Any time you’re partitioning elements and merging partitions, Union-Find should be your first thought.
The beauty of Union-Find lies in its simplicity. The naive implementation takes five minutes to write. Add two optimizations, and you have a data structure with effectively constant-time operations that handles millions of elements without breaking a sweat.
Core Operations: Find and Union
The data structure maintains a forest of trees, where each tree represents one set. Every element points to a parent, and the root of each tree serves as the set’s representative (or “leader”). Two elements belong to the same set if and only if they have the same root.
Here’s the naive implementation:
class UnionFind:
    def __init__(self, n: int):
        # Each element starts as its own parent (its own set)
        self.parent = list(range(n))

    def find(self, x: int) -> int:
        """Find the root/representative of x's set."""
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> bool:
        """Merge sets containing x and y. Returns True if they were separate."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False  # Already in the same set
        self.parent[root_x] = root_y
        return True

    def connected(self, x: int, y: int) -> bool:
        """Check if x and y are in the same set."""
        return self.find(x) == self.find(y)
This works, but has a critical flaw. Each union operation just attaches one root to another, potentially creating long chains. In the worst case, you end up with a linked list, making find() operations O(n). After m operations on n elements, you’re looking at O(mn) total time—unacceptable for large inputs.
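To see the worst case concretely, here is a small sketch: an instrumented copy of the naive class (the hop counter is added purely for illustration), driven by the sequential unions that produce a linked-list-shaped tree.

```python
# Instrumented naive Union-Find: find() also reports how many parent
# pointers it had to follow.
class NaiveUF:
    def __init__(self, n: int):
        self.parent = list(range(n))

    def find(self, x: int) -> tuple[int, int]:
        hops = 0
        while self.parent[x] != x:
            x = self.parent[x]
            hops += 1
        return x, hops

    def union(self, x: int, y: int) -> None:
        rx, _ = self.find(x)
        ry, _ = self.find(y)
        if rx != ry:
            self.parent[rx] = ry

n = 1000
uf = NaiveUF(n)
for i in range(n - 1):
    uf.union(i, i + 1)   # each union re-attaches the growing chain's root

root, hops = uf.find(0)
print(root, hops)        # 999 999: find(0) walks the entire chain
```

With 1000 elements, a single find from the bottom of the chain takes 999 hops, which is exactly the O(n) behavior the optimizations below eliminate.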
Path Compression Optimization
Path compression fixes the “long chain” problem with a simple insight: when you traverse from a node to its root, you’ve done the work of finding the root for every node along that path. Why not update all of them to point directly to the root?
def find(self, x: int) -> int:
    """Find with path compression - all nodes point directly to root."""
    if self.parent[x] != x:
        self.parent[x] = self.find(self.parent[x])
    return self.parent[x]
This recursive implementation is elegant: it finds the root, then on the way back up the call stack, updates each node’s parent to point directly to the root. After one find operation, the entire path becomes flat.
If you prefer iterative code (useful for very deep trees to avoid stack overflow):
def find_iterative(self, x: int) -> int:
    """Iterative find with path compression."""
    root = x
    while self.parent[root] != root:
        root = self.parent[root]
    # Second pass: compress the path
    while self.parent[x] != root:
        next_x = self.parent[x]
        self.parent[x] = root
        x = next_x
    return root
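Either way, a single find flattens the whole search path. A tiny demonstration, with the chain wired up by hand purely for illustration:

```python
# Minimal Union-Find with recursive path compression.
class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

uf = UnionFind(5)
uf.parent = [1, 2, 3, 4, 4]   # hand-built chain: 0 -> 1 -> 2 -> 3 -> 4 (root)
uf.find(0)                    # one find from the deepest node...
print(uf.parent)              # [4, 4, 4, 4, 4]: the whole path is now flat
```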
Path compression alone brings the amortized time per operation down to O(log n). That’s already a massive improvement, but we can do better.
Union by Rank/Size Optimization
The second optimization prevents trees from becoming unbalanced in the first place. Instead of arbitrarily attaching one root to another, we attach the smaller tree under the larger one.
Two variants exist: union by rank (attach by tree height) and union by size (attach by node count). Union by size is often more practical because the size array answers a question you frequently need anyway: how many elements are in each set.
class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n   # Tree height (upper bound)
        self.size = [1] * n   # Number of elements in each set

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union_by_rank(self, x: int, y: int) -> bool:
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        # Attach smaller rank tree under larger rank tree
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        return True

    def union_by_size(self, x: int, y: int) -> bool:
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        # Attach smaller tree under larger tree
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        return True

    def get_size(self, x: int) -> int:
        """Return the size of the set containing x."""
        return self.size[self.find(x)]
With union by rank, the rank only increases when merging two trees of equal rank. This keeps trees shallow—the maximum height is O(log n) even without path compression.
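That logarithmic bound is easy to check empirically. The sketch below uses union by rank with no path compression and probes the deepest node after a burst of random unions (the seed, element count, and union count are arbitrary choices for the demonstration):

```python
import math
import random

# Union by rank only (no path compression), plus a depth probe used to
# verify the O(log n) height bound empirically.
class RankUF:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def depth(self, x: int) -> int:
        d = 0
        while self.parent[x] != x:
            x = self.parent[x]
            d += 1
        return d

    def union(self, x: int, y: int) -> None:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1

random.seed(0)   # fixed seed for reproducibility
n = 1024
uf = RankUF(n)
for _ in range(5000):
    uf.union(random.randrange(n), random.randrange(n))
max_depth = max(uf.depth(i) for i in range(n))
print(max_depth, "<=", int(math.log2(n)))   # never exceeds log2(1024) = 10
```

Whatever the union order, no node ends up deeper than log2(n), because a root's rank bounds its tree's height and rank only grows on equal-rank merges.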
Time Complexity Analysis
When you combine path compression with union by rank or size, something remarkable happens. The amortized time per operation drops to O(α(n)), where α is the inverse Ackermann function.
The Ackermann function grows absurdly fast—A(4, 4) has more digits than atoms in the observable universe. Its inverse, therefore, grows absurdly slowly. For any practical input size (n < 10^80), α(n) ≤ 4. For all intents and purposes, it’s constant.
This means m operations on n elements take O(m × α(n)) ≈ O(m) time. You won’t find a simpler data structure with better theoretical guarantees for this problem.
Practical Applications
Cycle Detection in Undirected Graphs
When processing edges, a cycle exists if both endpoints are already in the same set:
def has_cycle(n: int, edges: list[tuple[int, int]]) -> bool:
    uf = UnionFind(n)
    for u, v in edges:
        if uf.find(u) == uf.find(v):
            return True  # u and v already connected = cycle
        uf.union_by_size(u, v)
    return False
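A quick run on two small graphs, with a compact version of the DSU repeated inline so the snippet is self-contained:

```python
# Compact DSU (path compression + union by size).
class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union_by_size(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        self.size[rx] += self.size[ry]
        return True

def has_cycle(n: int, edges: list[tuple[int, int]]) -> bool:
    uf = UnionFind(n)
    for u, v in edges:
        if uf.find(u) == uf.find(v):
            return True
        uf.union_by_size(u, v)
    return False

print(has_cycle(4, [(0, 1), (1, 2), (2, 3)]))  # False: a simple path, no cycle
print(has_cycle(3, [(0, 1), (1, 2), (2, 0)]))  # True: the third edge closes a triangle
```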
Finding Connected Components
Count the number of distinct roots after processing all edges:
def count_components(n: int, edges: list[tuple[int, int]]) -> int:
    uf = UnionFind(n)
    for u, v in edges:
        uf.union_by_size(u, v)
    # Count unique roots
    return sum(1 for i in range(n) if uf.find(i) == i)
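For example, a five-node graph whose edges form two clusters (the compact DSU is repeated inline so the snippet runs on its own):

```python
# Compact DSU (path compression + union by size).
class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union_by_size(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        self.size[rx] += self.size[ry]
        return True

def count_components(n: int, edges: list[tuple[int, int]]) -> int:
    uf = UnionFind(n)
    for u, v in edges:
        uf.union_by_size(u, v)
    return sum(1 for i in range(n) if uf.find(i) == i)

print(count_components(5, [(0, 1), (1, 2), (3, 4)]))  # 2: {0,1,2} and {3,4}
print(count_components(3, []))                        # 3: no edges, all singletons
```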
Kruskal’s Minimum Spanning Tree
Sort edges by weight, then greedily add edges that don’t create cycles:
def kruskal_mst(n: int, edges: list[tuple[int, int, int]]) -> list:
    """edges: list of (u, v, weight)"""
    uf = UnionFind(n)
    mst = []
    for u, v, weight in sorted(edges, key=lambda e: e[2]):
        if uf.union_by_size(u, v):
            mst.append((u, v, weight))
            if len(mst) == n - 1:
                break
    return mst
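A worked example on a small made-up graph (the compact DSU is repeated inline so the snippet is self-contained):

```python
# Compact DSU (path compression + union by size).
class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union_by_size(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        self.size[rx] += self.size[ry]
        return True

def kruskal_mst(n: int, edges: list[tuple[int, int, int]]) -> list:
    uf = UnionFind(n)
    mst = []
    for u, v, weight in sorted(edges, key=lambda e: e[2]):
        if uf.union_by_size(u, v):
            mst.append((u, v, weight))
            if len(mst) == n - 1:
                break
    return mst

# 4 nodes, 5 candidate edges (u, v, weight); the edge (0, 1, 4) is skipped
# because 0 and 1 are already connected when it is considered.
edges = [(0, 1, 4), (0, 2, 1), (2, 1, 2), (1, 3, 5), (2, 3, 8)]
mst = kruskal_mst(4, edges)
print(mst)                        # [(0, 2, 1), (2, 1, 2), (1, 3, 5)]
print(sum(w for _, _, w in mst))  # total weight 8
```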
Variations and Extensions
Weighted Union-Find
Sometimes you need to track relationships between elements, not just connectivity. Weighted Union-Find stores a value (often a distance or ratio) on each edge to the parent:
class WeightedUnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.diff = [0.0] * n  # diff[x] = value[x] - value[parent[x]]

    def find(self, x: int) -> tuple[int, float]:
        """Returns (root, value[x] - value[root])."""
        if self.parent[x] == x:
            return x, 0.0
        root, d = self.find(self.parent[x])
        self.parent[x] = root
        self.diff[x] += d
        return root, self.diff[x]

    def union(self, x: int, y: int, w: float) -> bool:
        """Assert that value[x] - value[y] = w."""
        root_x, diff_x = self.find(x)
        root_y, diff_y = self.find(y)
        if root_x == root_y:
            return abs(diff_x - diff_y - w) < 1e-9  # Check consistency
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
            self.diff[root_x] = diff_y + w - diff_x
        else:
            self.parent[root_y] = root_x
            self.diff[root_y] = diff_x - w - diff_y
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_x] += 1
        return True
This handles problems like “A is 3 units heavier than B, B is 2 units heavier than C, how much heavier is A than C?”
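That exact question plays out like this, with elements 0, 1, 2 standing for A, B, C (the class is repeated inline so the example runs standalone; the 1e-9 tolerance matches the implementation):

```python
# Weighted Union-Find: diff[x] = value[x] - value[parent[x]].
class WeightedUnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.diff = [0.0] * n

    def find(self, x: int) -> tuple[int, float]:
        if self.parent[x] == x:
            return x, 0.0
        root, d = self.find(self.parent[x])
        self.parent[x] = root
        self.diff[x] += d
        return root, self.diff[x]

    def union(self, x: int, y: int, w: float) -> bool:
        root_x, diff_x = self.find(x)
        root_y, diff_y = self.find(y)
        if root_x == root_y:
            return abs(diff_x - diff_y - w) < 1e-9
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
            self.diff[root_x] = diff_y + w - diff_x
        else:
            self.parent[root_y] = root_x
            self.diff[root_y] = diff_x - w - diff_y
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_x] += 1
        return True

w = WeightedUnionFind(3)
w.union(0, 1, 3.0)          # A is 3 units heavier than B
w.union(1, 2, 2.0)          # B is 2 units heavier than C
_, da = w.find(0)
_, dc = w.find(2)
print(da - dc)              # 5.0: A is 5 units heavier than C
print(w.union(0, 2, 5.0))   # True: consistent with what is already recorded
```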
Union-Find with Rollback
For competitive programming scenarios requiring undo operations, skip path compression (it’s irreversible) and maintain a stack of operations:
class RollbackUnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.history = []  # Stack of (attached_root, surviving_root, old_rank)

    def find(self, x: int) -> int:
        # No path compression!
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> bool:
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        # Record which root was attached and the surviving root's old rank
        self.history.append((root_y, root_x, self.rank[root_x]))
        self.parent[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        return True

    def rollback(self):
        if self.history:
            child, root, old_rank = self.history.pop()
            self.parent[child] = child  # a root is always its own parent before a union
            self.rank[root] = old_rank
Without path compression, operations are O(log n) instead of O(α(n)), but that’s the price of reversibility.
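A short demonstration of the undo behavior (the class is repeated in compact form so the snippet is self-contained; the history records the surviving root so its rank can be restored):

```python
# Rollback DSU: no path compression, every merge is recorded on a stack.
class RollbackUnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.history = []  # (attached_root, surviving_root, old_rank)

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.history.append((ry, rx, self.rank[rx]))
        self.parent[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True

    def rollback(self):
        if self.history:
            child, root, old_rank = self.history.pop()
            self.parent[child] = child
            self.rank[root] = old_rank

uf = RollbackUnionFind(4)
uf.union(0, 1)
uf.union(2, 3)
uf.union(1, 2)                       # now all four elements are connected
before = uf.find(0) == uf.find(3)    # True
uf.rollback()                        # undo union(1, 2)
after = uf.find(0) == uf.find(3)     # False: back to {0,1} and {2,3}
print(before, after)
```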
Union-Find is one of those rare data structures that’s both theoretically elegant and immediately practical. Master it, and you’ll recognize opportunities to apply it everywhere—from graph algorithms to distributed systems to game development. The implementation is short enough to write from memory, and the performance is hard to beat.