philosyang.com

128. Longest Consecutive Sequence: Union Find

We first scaffold our union find:

class Solution:
    def longestConsecutive(self, nums: List[int]) -> int:

        class DSU:
            def __init__(self, n):
                self.rank = [0] * n
                self.parent = list(range(n))

            def find(self, i):
                # path compression: point i straight at its root
                if self.parent[i] != i:
                    self.parent[i] = self.find(self.parent[i])
                return self.parent[i]

            def union(self, i, j):
                ip = self.find(i)
                jp = self.find(j)

                if ip == jp:
                    return  # already in the same component

                # union by rank: hang the shorter tree under the taller one
                if self.rank[ip] < self.rank[jp]:
                    self.parent[ip] = jp
                elif self.rank[ip] > self.rank[jp]:
                    self.parent[jp] = ip
                else:
                    self.parent[jp] = ip
                    self.rank[ip] += 1

        return -1  # placeholder until the DSU is wired in
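Before wiring the DSU into the problem, it can help to try the scaffold on its own. The minimal sketch below lifts the same DSU class to module level and unions a few arbitrary indices. The indices are made up for illustration and have nothing to do with nums yet.

```python
class DSU:
    """Union find with path compression and union by rank (same logic as the scaffold)."""

    def __init__(self, n):
        self.rank = [0] * n
        self.parent = list(range(n))

    def find(self, i):
        # path compression: point i straight at its root on the way back up
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        ip = self.find(i)
        jp = self.find(j)
        if ip == jp:
            return
        # union by rank: attach the shallower tree under the deeper one
        if self.rank[ip] < self.rank[jp]:
            self.parent[ip] = jp
        elif self.rank[ip] > self.rank[jp]:
            self.parent[jp] = ip
        else:
            self.parent[jp] = ip
            self.rank[ip] += 1


uf = DSU(5)
uf.union(0, 1)
uf.union(3, 4)
uf.union(1, 4)                   # joins {0, 1} with {3, 4}
print(uf.find(0) == uf.find(4))  # True
print(uf.find(2) == uf.find(0))  # False: index 2 was never unioned
```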

Next, we need a way to actually use our DSU in this problem.

In this problem, we want to union all consecutive numbers. We also need to keep track of the size of each connected component, since the answer is the size of the largest one.

class Solution:
    def longestConsecutive(self, nums: List[int]) -> int:

        class DSU:
            def __init__(self, n):
                self.rank = [0] * n
                self.parent = list(range(n))
                self.size = [1] * n  # component size; every num starts as a singleton

            def find(self, i):
                if self.parent[i] != i:
                    self.parent[i] = self.find(self.parent[i])
                return self.parent[i]

            def union(self, i, j):
                ip = self.find(i)
                jp = self.find(j)

                if ip == jp:
                    return

                if self.rank[ip] < self.rank[jp]:
                    self.parent[ip] = jp
                    self.size[jp] += self.size[ip]  # update sizes upon union
                elif self.rank[ip] > self.rank[jp]:
                    self.parent[jp] = ip
                    self.size[ip] += self.size[jp]
                else:
                    self.parent[jp] = ip
                    self.rank[ip] += 1
                    self.size[ip] += self.size[jp]
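The minimal sketch below shows how the sizes evolve. It lifts the same size-tracking DSU to module level and uses arbitrary indices.

Size is only kept current at roots. After a union, the absorbed root keeps its old value. Those stale entries can never exceed their current root's size, because every union adds the absorbed size onto the surviving root. That is why max(uf.size) over the whole list still returns the largest component.

```python
class DSU:
    """Union by rank with path compression, plus component sizes (valid at roots)."""

    def __init__(self, n):
        self.rank = [0] * n
        self.parent = list(range(n))
        self.size = [1] * n  # every index starts as a singleton component

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        ip = self.find(i)
        jp = self.find(j)
        if ip == jp:
            return
        if self.rank[ip] < self.rank[jp]:
            self.parent[ip] = jp
            self.size[jp] += self.size[ip]  # surviving root absorbs the other size
        elif self.rank[ip] > self.rank[jp]:
            self.parent[jp] = ip
            self.size[ip] += self.size[jp]
        else:
            self.parent[jp] = ip
            self.rank[ip] += 1
            self.size[ip] += self.size[jp]


uf = DSU(5)
uf.union(0, 1)   # {0, 1}, root 0
uf.union(2, 3)   # {2, 3}, root 2
uf.union(0, 2)   # {0, 1, 2, 3}, root 0
print(uf.size)   # [4, 1, 2, 1, 1]: index 2 holds a stale 2, the root holds 4
print(max(uf.size))  # 4
```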

We need a way to put the nums into our DSU.

        n = len(nums)
        uf = DSU(n)

        for i, num in enumerate(nums):
            if num + 1 in ...

We need a hash-based structure so that each membership check takes O(1) on average, instead of the O(n) scan we would get from searching a list.

        n = len(nums)
        lookup = set(nums)  # hashset
        uf = DSU(n)

        for i, num in enumerate(nums):
            if num + 1 in lookup:
                uf.union(i, i+1)

        return max(uf.size)

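The code above is incorrect. i + 1 is just the next position in nums, not the index of num + 1. We want to union the indices of these two numbers, so we need the index where num + 1 actually lives. When the last element's successor exists, i + 1 even runs past the end of the DSU's arrays.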
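To make the mismatch concrete, the sketch below uses a hypothetical helper, `show_index_mismatch`, written only for illustration. Whenever num + 1 exists, it reports what actually sits at position i + 1. The input is [100, 4, 200, 1, 3, 2].

```python
def show_index_mismatch(nums):
    """For each num whose successor exists, report what sits at position i + 1."""
    lookup = set(nums)
    report = []
    for i, num in enumerate(nums):
        if num + 1 in lookup:
            # None means i + 1 runs past the end: uf.union(i, i + 1) would raise IndexError
            at_i_plus_1 = nums[i + 1] if i + 1 < len(nums) else None
            report.append((num, num + 1, at_i_plus_1))
    return report


for num, wanted, got in show_index_mismatch([100, 4, 200, 1, 3, 2]):
    print(f"num={num}: wanted {wanted}, position i+1 holds {got}")
```

For this input, the helper reports (1, 2, 3), (3, 4, 2) and (2, 3, None). Position i + 1 never holds num + 1 here, and for the last element the position does not exist at all.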

So we upgrade our hashset to a hashmap from each number to its index.

        n = len(nums)
        lookup = {num: idx for idx, num in enumerate(nums)}  # num -> idx, free real estate
        uf = DSU(n)

        for i, num in enumerate(nums):
            if num + 1 in lookup:
                uf.union(i, lookup[num + 1])

        return max(uf.size)

We are close, yet this code is still incorrect.

When the same number shows up twice, both copies can end up in the same component, so that number is counted twice in the component's size. That is wrong because we want the length of the longest consecutive sequence: [1,2,2,3] should return 3, but this code returns 4.

Therefore, we dedupe nums first. Deduping is safe because it does not change the longest consecutive sequence.

        uniq = set(nums)    # dedupe
        n = len(uniq)
        lookup = {num: idx for idx, num in enumerate(uniq)}
        uf = DSU(n)

        for i, num in enumerate(uniq):
            if num + 1 in uniq:
                uf.union(i, lookup[num + 1])

        return max(uf.size)
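To confirm both the bug and the fix, the sketch below runs the hashmap approach on [1,2,2,3], once without deduping and once with it.

- The helper `longest_with_dsu` and its `dedupe` flag are hypothetical.
- The DSU is trimmed to union by size for brevity.

It mirrors the logic of the snippets above rather than reproducing them exactly.

```python
class DSU:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, i):
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        ip, jp = self.find(i), self.find(j)
        if ip == jp:
            return
        if self.size[ip] < self.size[jp]:
            ip, jp = jp, ip  # keep the larger component's root on top
        self.parent[jp] = ip
        self.size[ip] += self.size[jp]


def longest_with_dsu(nums, dedupe):
    values = list(set(nums)) if dedupe else list(nums)
    lookup = {num: idx for idx, num in enumerate(values)}  # duplicates: last index wins
    uf = DSU(len(values))
    for i, num in enumerate(values):
        if num + 1 in lookup:
            uf.union(i, lookup[num + 1])
    return max(uf.size)


print(longest_with_dsu([1, 2, 2, 3], dedupe=False))  # 4: the second 2 is counted twice
print(longest_with_dsu([1, 2, 2, 3], dedupe=True))   # 3
```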

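Besides the method above, we can also skip duplicates while building lookup if we don't want to dedupe. In that case, the union loop must iterate over lookup rather than nums; otherwise, the duplicate indices still get unioned with their neighbors. (Personally, I prefer deduping.)

        lookup = {}
        for i, num in enumerate(nums):
            if num not in lookup:  # ignore duplicates: keep the first index only
                lookup[num] = i
        uf = DSU(len(nums))  # duplicate indices just stay size-1 singletons
        for num, i in lookup.items():  # iterate unique values, not nums
            if num + 1 in lookup:
                uf.union(i, lookup[num + 1])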
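Put together, the skip-duplicates alternative could look like the minimal sketch below.

- The function name `longest_consecutive_skip_dupes` is hypothetical.
- The DSU uses union by size for brevity.
- The union loop walks lookup rather than nums. Each value is therefore unioned through exactly one index, and the duplicate indices stay as size-1 singletons that never beat a real sequence.

```python
from typing import List


class DSU:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, i: int) -> int:
        if self.parent[i] != i:
            self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i: int, j: int) -> None:
        ip, jp = self.find(i), self.find(j)
        if ip == jp:
            return
        if self.size[ip] < self.size[jp]:
            ip, jp = jp, ip  # attach the smaller component under the larger one
        self.parent[jp] = ip
        self.size[ip] += self.size[jp]


def longest_consecutive_skip_dupes(nums: List[int]) -> int:
    if not nums:
        return 0  # max() of an empty list would raise ValueError
    lookup = {}
    for i, num in enumerate(nums):
        if num not in lookup:  # keep only the first index of each value
            lookup[num] = i
    uf = DSU(len(nums))  # duplicate indices are never touched and stay size 1
    for num, i in lookup.items():
        if num + 1 in lookup:
            uf.union(i, lookup[num + 1])
    return max(uf.size)


print(longest_consecutive_skip_dupes([1, 2, 2, 3]))  # 3
```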

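Now we have correct code with O(n × α(n)) time, where α(n) is the inverse Ackermann function, which grows so slowly that it is effectively constant. We also guard against empty input, because max(uf.size) would raise a ValueError on an empty list.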

from typing import List


class Solution:
    def longestConsecutive(self, nums: List[int]) -> int:
        if not nums:
            return 0

        class DSU:
            def __init__(self, n):
                self.rank = [0] * n
                self.parent = list(range(n))
                self.size = [1] * n  # size of the component rooted at this index

            def find(self, i):
                if self.parent[i] != i:
                    self.parent[i] = self.find(self.parent[i])
                return self.parent[i]

            def union(self, i, j):
                ip = self.find(i)
                jp = self.find(j)

                if ip == jp:
                    return

                if self.rank[ip] < self.rank[jp]:
                    self.parent[ip] = jp
                    self.size[jp] += self.size[ip]
                elif self.rank[ip] > self.rank[jp]:
                    self.parent[jp] = ip
                    self.size[ip] += self.size[jp]
                else:
                    self.parent[jp] = ip
                    self.rank[ip] += 1
                    self.size[ip] += self.size[jp]

        uniq = set(nums)
        n = len(uniq)
        lookup = {num: idx for idx, num in enumerate(uniq)}
        uf = DSU(n)

        for i, num in enumerate(uniq):
            if num + 1 in uniq:
                uf.union(i, lookup[num + 1])

        return max(uf.size)
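One way to gain confidence in this solution is a randomized cross-check against a simple reference. The sketch below is a test harness, not part of the solution.

- `longest_by_sorting` is a hypothetical O(n log n) reference that sorts the distinct values and counts runs.
- The input ranges and the seed are arbitrary.
- The Solution class is a condensed copy of the code above, so the snippet runs on its own.

```python
import random
from typing import List


class Solution:
    # condensed copy of the union-by-rank solution above
    def longestConsecutive(self, nums: List[int]) -> int:
        if not nums:
            return 0

        class DSU:
            def __init__(self, n):
                self.rank = [0] * n
                self.parent = list(range(n))
                self.size = [1] * n

            def find(self, i):
                if self.parent[i] != i:
                    self.parent[i] = self.find(self.parent[i])
                return self.parent[i]

            def union(self, i, j):
                ip, jp = self.find(i), self.find(j)
                if ip == jp:
                    return
                if self.rank[ip] < self.rank[jp]:
                    self.parent[ip] = jp
                    self.size[jp] += self.size[ip]
                elif self.rank[ip] > self.rank[jp]:
                    self.parent[jp] = ip
                    self.size[ip] += self.size[jp]
                else:
                    self.parent[jp] = ip
                    self.rank[ip] += 1
                    self.size[ip] += self.size[jp]

        uniq = set(nums)
        lookup = {num: idx for idx, num in enumerate(uniq)}
        uf = DSU(len(uniq))
        for i, num in enumerate(uniq):
            if num + 1 in uniq:
                uf.union(i, lookup[num + 1])
        return max(uf.size)


def longest_by_sorting(nums: List[int]) -> int:
    """Reference answer: sort the distinct values and count the longest run."""
    values = sorted(set(nums))
    best = run = 0
    for k, v in enumerate(values):
        run = run + 1 if k > 0 and v == values[k - 1] + 1 else 1
        best = max(best, run)
    return best


random.seed(0)
for _ in range(1000):
    nums = [random.randint(-10, 10) for _ in range(random.randint(0, 15))]
    assert Solution().longestConsecutive(nums) == longest_by_sorting(nums), nums
print("all random tests passed")
```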

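We can improve this code further. The lowest-hanging fruit is to ditch rank and union by size alone. We already track size, and attaching the smaller component under the larger one gives the same near-constant amortized bound when combined with path compression.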

from typing import List


class Solution:
    def longestConsecutive(self, nums: List[int]) -> int:
        if not nums:
            return 0

        class DSU:
            def __init__(self, n):
                self.parent = list(range(n))
                self.size = [1] * n  # size of the component rooted at this index

            def find(self, i):
                if self.parent[i] != i:
                    self.parent[i] = self.find(self.parent[i])
                return self.parent[i]

            def union(self, i, j):
                ip = self.find(i)
                jp = self.find(j)

                if ip == jp:
                    return

                if self.size[ip] < self.size[jp]:
                    self.parent[ip] = jp
                    self.size[jp] += self.size[ip]
                else:   # without rank, the > and == cases do exactly the same thing
                    self.parent[jp] = ip
                    self.size[ip] += self.size[jp]

        uniq = set(nums)
        n = len(uniq)
        lookup = {num: idx for idx, num in enumerate(uniq)}
        uf = DSU(n)

        for i, num in enumerate(uniq):
            if num + 1 in uniq:
                uf.union(i, lookup[num + 1])

        return max(uf.size)
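As a quick usage check, the sketch below runs the union-by-size version on a few inputs. The class is copied so the snippet is standalone. The inputs are arbitrary picks that cover scattered values, duplicates, negative numbers, and an empty list.

```python
from typing import List


class Solution:
    # copy of the union-by-size solution above
    def longestConsecutive(self, nums: List[int]) -> int:
        if not nums:
            return 0

        class DSU:
            def __init__(self, n):
                self.parent = list(range(n))
                self.size = [1] * n

            def find(self, i):
                if self.parent[i] != i:
                    self.parent[i] = self.find(self.parent[i])
                return self.parent[i]

            def union(self, i, j):
                ip, jp = self.find(i), self.find(j)
                if ip == jp:
                    return
                if self.size[ip] < self.size[jp]:
                    self.parent[ip] = jp
                    self.size[jp] += self.size[ip]
                else:  # ties go either way without rank
                    self.parent[jp] = ip
                    self.size[ip] += self.size[jp]

        uniq = set(nums)
        lookup = {num: idx for idx, num in enumerate(uniq)}
        uf = DSU(len(uniq))
        for i, num in enumerate(uniq):
            if num + 1 in uniq:
                uf.union(i, lookup[num + 1])
        return max(uf.size)


cases = ([100, 4, 200, 1, 3, 2], [0, 3, 7, 2, 5, 8, 4, 6, 0, 1], [1, 2, 2, 3], [-2, -1, 0, 5], [])
for nums in cases:
    print(nums, "->", Solution().longestConsecutive(nums))
```

It should print 4, 9, 3, 3 and 0 for the five inputs, in that order.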

We could also build lookup and do the unions in the same loop, but our union find is already correct and concise enough for this practice.
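For the single-loop idea, here is a minimal sketch under a few stated assumptions:

- Each new distinct value gets the next free index.
- Duplicates are skipped.
- The new index is unioned with num - 1 and num + 1 if either has already been seen.

The function name `longest_consecutive_one_pass` is hypothetical, and the DSU is inlined as two lists with union by size.

```python
from typing import Dict, List


def longest_consecutive_one_pass(nums: List[int]) -> int:
    parent: List[int] = []
    size: List[int] = []
    lookup: Dict[int, int] = {}  # value -> index in parent/size

    def find(i: int) -> int:
        if parent[i] != i:
            parent[i] = find(parent[i])  # path compression
        return parent[i]

    def union(i: int, j: int) -> None:
        ip, jp = find(i), find(j)
        if ip == jp:
            return
        if size[ip] < size[jp]:
            ip, jp = jp, ip  # union by size: smaller root goes under the larger root
        parent[jp] = ip
        size[ip] += size[jp]

    for num in nums:
        if num in lookup:  # skip duplicates so they are never double counted
            continue
        idx = len(parent)
        lookup[num] = idx
        parent.append(idx)
        size.append(1)
        # only neighbors that arrived earlier are visible now; a later neighbor
        # will find this value when it is inserted, so every adjacent pair meets once
        for neighbor in (num - 1, num + 1):
            if neighbor in lookup:
                union(idx, lookup[neighbor])

    return max(size, default=0)


print(longest_consecutive_one_pass([100, 4, 200, 1, 3, 2]))  # 4
```

Checking both neighbors is what makes one pass enough. Whichever of two adjacent values arrives second sees the first in lookup and unions them.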

Union find is cool!

#Neetcode150 #Array #Union-Find #Python