philosyang.com

Union Find (Disjoint Set Union) Walkthrough

Time to master union find to end this all.

Union Find, also known as Disjoint Set Union (DSU), is a data structure that keeps track of a partition of a set into disjoint (non-overlapping) subsets. It provides two primary operations: find and union. The find operation returns the root (the representative) of the set containing a given element, while the union operation merges two sets into one. This makes it particularly useful whenever we need to track the connected components of a graph, as in Kruskal's algorithm for finding a Minimum Spanning Tree, or when solving the Longest Consecutive Sequence problem efficiently. Techniques like path compression and union by rank significantly improve its efficiency: they keep the trees flat, so each operation runs in amortized nearly constant time.
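To make find and union concrete, here's a minimal sketch that counts the connected components of an undirected graph. The function name `count_components` and the edge-list input format are assumptions for illustration; it uses only the plain, unoptimized find/union idea, with find written as a loop.

```python
def count_components(n, edges):
    """Count connected components of an undirected graph (hypothetical helper).

    n: number of nodes labelled 0..n-1; edges: iterable of (u, v) pairs.
    """
    parent = list(range(n))  # every node starts as its own set

    def find(i):
        # follow parent pointers until we reach the root (a self-parent)
        while parent[i] != i:
            i = parent[i]
        return i

    components = n
    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            parent[root_u] = root_v  # union: merge the two sets
            components -= 1  # each successful merge removes one component
    return components


print(count_components(5, [(0, 1), (1, 2), (3, 4)]))  # 2
```

An edge whose endpoints already share a root changes nothing, which is exactly the check Kruskal's algorithm uses to skip edges that would form a cycle.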

It’s a versatile and useful structure to tackle quite a range of problems, and the interviewer can expect you to have it in your toolbox.

naive implementation for Union Find

class DSU:  # Disjoint Set Union
    def __init__(self, size):
        # init parent array with each element as its own parent
        self.parent = list(range(size))

    def find(self, i):
        if self.parent[i] == i:
            return i

        return self.find(self.parent[i])

    def union(self, i, j):
        i_parent = self.find(i)
        j_parent = self.find(j)

        self.parent[i_parent] = j_parent

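Straightforward, but not enough for tech rounds: nothing stops a tree from degenerating into a long chain, in which case a single find() takes O(n) time.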
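Here's a small sketch of that worst case, assuming the naive union above, which always hangs the first argument's root under the second's. Calling union(0, 1), union(1, 2), ... builds one chain 0 -> 1 -> ... -> n-1, so finding the root of 0 walks through every node. The class name `NaiveDSU` and the hop-counting helper `find_with_hops` are hypothetical; find is written as a loop here because the recursive version can also run into Python's default recursion limit (1000) on chains this long.

```python
class NaiveDSU:
    """Same linking rule as the naive DSU; find is iterative and counts hops."""

    def __init__(self, size):
        self.parent = list(range(size))

    def find_with_hops(self, i):
        # walk up to the root, counting how many parent pointers we follow
        hops = 0
        while self.parent[i] != i:
            i = self.parent[i]
            hops += 1
        return i, hops

    def union(self, i, j):
        i_root, _ = self.find_with_hops(i)
        j_root, _ = self.find_with_hops(j)
        self.parent[i_root] = j_root  # no balancing: i's tree always goes under j's


n = 1000
dsu = NaiveDSU(n)
for k in range(n - 1):
    dsu.union(k, k + 1)  # each union extends the chain by one node

print(dsu.find_with_hops(0))  # (999, 999): one find touches every node
```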

Path Compression in find()

    def find(self, i):
        # if I am the root of my set
        if self.parent[i] == i:
            return i

        # if I am not the root: find the root through my parent and make it my parent (compression);
        # the recursion also points every node on the way up directly at that common root
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

This assignment points every node on the path from i to the root directly at that root, so the next find() on any of them takes a single hop. Repeated finds keep the trees flat, whereas the naive method can degenerate into a linked-list-like chain rather than a bushy tree in the worst case.
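If recursion depth is a concern, the same compression can be done iteratively in two passes: first walk up to locate the root, then walk the path again and point each node straight at it. This is a sketch of that variant on a plain parent list; `find_iterative` is a hypothetical standalone helper, not the post's method.

```python
def find_iterative(parent, i):
    """Path-compressing find without recursion (hypothetical helper).

    parent: list where parent[x] == x marks a root. Mutates parent in place.
    """
    # pass 1: locate the root
    root = i
    while parent[root] != root:
        root = parent[root]

    # pass 2: re-point every node on the path from i directly at the root
    while parent[i] != root:
        next_node = parent[i]
        parent[i] = root
        i = next_node
    return root


parent = [1, 2, 3, 4, 4]  # chain 0 -> 1 -> 2 -> 3 -> 4 (4 is the root)
print(find_iterative(parent, 0))  # 4
print(parent)  # [4, 4, 4, 4, 4]: the whole path now points at the root
```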

Union by Rank

We can optimize union() further. When merging two trees, we need to decide whether to join A into B or B into A, so that the merged tree stays as short as possible.

If A is shallow and B is deep (taller in height), we would want to join A into B rather than the opposite. Let me explain:

For example, say tree A has height 3 and tree B has height 5. Joining A into B keeps the overall height at 5: everything in A now hangs one level below B's root, so A's deepest node ends up at 1 + 3 = 4, which is still shallower than B's own height of 5. The maximum height does not grow.
If you do it the opposite way, the maximum height becomes max(3, 5 + 1) = 6.

If both trees have the same height, the maximum height grows by 1 whichever way you join them.
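The arithmetic above boils down to one rule: hanging a tree of height h_child under the root of a tree of height h_parent gives height max(h_parent, h_child + 1), whether you count height in nodes or in edges, as long as you are consistent. A tiny sketch; `merged_height` is a hypothetical helper.

```python
def merged_height(h_parent, h_child):
    """Height after hanging the child tree's root under the parent tree's root."""
    # the child tree sits one level below the parent's root
    return max(h_parent, h_child + 1)


print(merged_height(5, 3))  # 5: join A (3) into B (5), height unchanged
print(merged_height(3, 5))  # 6: join B (5) into A (3), height grows
print(merged_height(4, 4))  # 5: equal heights, +1 whichever way
```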

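Rank is not necessarily equal to height. Rank is only an approximation, more precisely an upper bound:

Our current setup cannot track heights accurately. Why? Because path compression flattens trees without updating any heights. We could maintain exact heights, but we needn't for DSU. Why? Because rank never underestimates height (rank >= height always holds, since every union keeps the merged height within the new root's rank, and path compression only ever makes trees shorter), and a root of rank r always has at least 2^r nodes in its set. Together, these keep every tree's height at most log2(n), and that is all we need. Note that once compression has flattened some trees, rank A >= rank B no longer strictly guarantees height A >= height B; rank is a safe bound for choosing the new root, not an exact measure.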

    def __init__(self, n):
        self.rank = [0] * n  # we need another list to keep rough track of heights
        self.parent = list(range(n))

    # def find() omitted

    def union(self, i, j):
        i_parent = self.find(i)
        j_parent = self.find(j)

        if i_parent == j_parent:
            return

        # if tree i is shallower than j
        if self.rank[i_parent] < self.rank[j_parent]:
            # join tree i into j
            self.parent[i_parent] = j_parent
        elif self.rank[i_parent] > self.rank[j_parent]:
            self.parent[j_parent] = i_parent
        # if tree i and j have the same height
        else:
            # either works, but you need to increment rank for the chosen parent - "Uneasy lies the head that wears a crown."
            self.parent[j_parent] = i_parent
            self.rank[i_parent] += 1
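To see union by rank pay off on its own, here's a sketch with path compression deliberately left out (without compression, rank and height coincide). `RankOnlyDSU`, `depth`, and `max_depth` are hypothetical names; `union` follows the same rule as the snippet above: the lower-rank root goes under the higher-rank one, and a tie bumps the rank. The sequence union(0, 1), union(1, 2), ... that would build a linked-list-like chain with the naive union now yields a tree of depth 1, and the worst case (repeatedly merging equal-rank trees) only reaches depth log2(n).

```python
import math


class RankOnlyDSU:
    """Union by rank WITHOUT path compression (hypothetical demo class)."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, i):
        # plain walk to the root; no compression, so tree shapes stay as built
        while self.parent[i] != i:
            i = self.parent[i]
        return i

    def depth(self, i):
        # number of parent pointers between i and its root
        d = 0
        while self.parent[i] != i:
            i = self.parent[i]
            d += 1
        return d

    def union(self, i, j):
        i_root, j_root = self.find(i), self.find(j)
        if i_root == j_root:
            return
        if self.rank[i_root] < self.rank[j_root]:
            i_root, j_root = j_root, i_root  # make i_root the higher-rank root
        self.parent[j_root] = i_root
        if self.rank[i_root] == self.rank[j_root]:
            self.rank[i_root] += 1  # tie: the merged tree may be one level taller


def max_depth(dsu):
    return max(dsu.depth(i) for i in range(len(dsu.parent)))


# 1) chain-style unions: every new node hangs directly under root 0
chain = RankOnlyDSU(1000)
for k in range(999):
    chain.union(k, k + 1)
print(max_depth(chain))  # 1

# 2) worst case for union by rank: keep merging equal-rank trees
n = 16
pairs = RankOnlyDSU(n)
step = 1
while step < n:
    for start in range(0, n, 2 * step):
        pairs.union(start, start + step)
    step *= 2
print(max_depth(pairs), math.log2(n))  # 4 4.0
```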

optimized implementation

class DSU:
    def __init__(self, n):
        self.rank = [0] * n
        self.parent = list(range(n))

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])

        return self.parent[i]

    def union(self, i, j):
        i_parent = self.find(i)
        j_parent = self.find(j)

        if i_parent == j_parent:
            return

        i_parent_rank = self.rank[i_parent]
        j_parent_rank = self.rank[j_parent]

        if i_parent_rank < j_parent_rank:
            self.parent[i_parent] = j_parent
        elif i_parent_rank > j_parent_rank:
            self.parent[j_parent] = i_parent
        else:
            self.parent[j_parent] = i_parent
            self.rank[i_parent] += 1
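A quick usage sketch of the class above, which also checks the rank-versus-height point from earlier: after a path-compressing find(), a root's rank can stay higher than its tree's actual height. The class body is repeated so the snippet runs on its own; `tree_height` is a hypothetical helper that walks parent pointers directly (it never calls find(), so it compresses nothing).

```python
class DSU:
    def __init__(self, n):
        self.rank = [0] * n
        self.parent = list(range(n))

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])  # path compression
        return self.parent[i]

    def union(self, i, j):
        i_parent, j_parent = self.find(i), self.find(j)
        if i_parent == j_parent:
            return
        if self.rank[i_parent] < self.rank[j_parent]:
            self.parent[i_parent] = j_parent
        elif self.rank[i_parent] > self.rank[j_parent]:
            self.parent[j_parent] = i_parent
        else:
            self.parent[j_parent] = i_parent
            self.rank[i_parent] += 1


def tree_height(dsu, root):
    """Largest number of parent hops from any node in root's set to root."""
    best = 0
    for node in range(len(dsu.parent)):
        depth, cur = 0, node
        while dsu.parent[cur] != cur:
            cur = dsu.parent[cur]
            depth += 1
        if cur == root:
            best = max(best, depth)
    return best


dsu = DSU(5)
dsu.union(0, 1)  # equal ranks: 1 goes under 0, rank[0] becomes 1
dsu.union(2, 3)  # equal ranks: 3 goes under 2, rank[2] becomes 1
dsu.union(0, 2)  # equal ranks again: 2 goes under 0, rank[0] becomes 2

print(dsu.rank[0], tree_height(dsu, 0))  # 2 2: the path 3 -> 2 -> 0
dsu.find(3)  # compresses 3 straight under 0
print(dsu.rank[0], tree_height(dsu, 0))  # 2 1: rank now overestimates height
print(dsu.find(1) == dsu.find(3), dsu.find(4) == dsu.find(0))  # True False
```

Keeping the stale rank is harmless here: union only compares ranks to pick the new root, so an overestimate can at worst make a slightly less ideal choice, never an incorrect merge.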

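Time Complexity:

Union by rank alone keeps every tree's height at most log2(n), so find() and union() take O(log n) time. With path compression and union by rank together, any sequence of m operations on n elements takes O(m α(n)) time, i.e., amortized O(α(n)) per operation, where α(n) is the inverse Ackermann function. α(n) grows so slowly that it is at most 4 for any input size you could realistically store, so in practice both operations are effectively constant time. The parent and rank arrays take O(n) space.

I might expand on α(n), the inverse Ackermann function, some day.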

#Algorithm #Union-Find #Python