Skip to content

Commit a4c78da

Browse files
committed
Add Disjoint Set Union (Union by Size) implementation with doctests
1 parent a71618f commit a4c78da

1 file changed

Lines changed: 94 additions & 0 deletions

File tree

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""
2+
Disjoint Set (Union by Size).
3+
Reference: https://en.wikipedia.org/wiki/Disjoint-set_data_structure
4+
"""
5+
6+
from __future__ import annotations
7+
8+
9+
class Node:
10+
def __init__(self, data: int) -> None:
11+
self.data = data
12+
self.size: int
13+
self.parent: Node
14+
15+
16+
def make_set(x: Node) -> None:
17+
"""
18+
Make x as a set.
19+
20+
>>> v = Node(1)
21+
>>> make_set(v)
22+
>>> v.size
23+
1
24+
>>> v.parent == v
25+
True
26+
"""
27+
x.size = 1
28+
x.parent = x
29+
30+
31+
def find_set(x: Node) -> Node:
32+
"""
33+
Return the representative (parent) of the set containing x.
34+
35+
>>> v = Node(1)
36+
>>> make_set(v)
37+
>>> find_set(v) == v
38+
True
39+
"""
40+
if x != x.parent:
41+
x.parent = find_set(x.parent) # Path compression
42+
return x.parent
43+
44+
45+
def union_set(x: Node, y: Node) -> None:
46+
"""
47+
Union of two sets by size.
48+
The root with the larger size becomes the parent.
49+
50+
>>> v = [Node(i) for i in range(4)]
51+
>>> for node in v: make_set(node)
52+
>>> union_set(v[0], v[1])
53+
>>> union_set(v[2], v[3])
54+
>>> union_set(v[1], v[3])
55+
>>> find_set(v[0]) == find_set(v[2])
56+
True
57+
"""
58+
x, y = find_set(x), find_set(y)
59+
if x == y:
60+
return
61+
62+
if x.size < y.size:
63+
x, y = y, x
64+
y.parent = x
65+
x.size += y.size
66+
67+
68+
def test_disjoint_set() -> None:
69+
"""
70+
>>> test_disjoint_set()
71+
"""
72+
vertex = [Node(i) for i in range(6)]
73+
for v in vertex:
74+
make_set(v)
75+
76+
union_set(vertex[0], vertex[1])
77+
union_set(vertex[1], vertex[2])
78+
union_set(vertex[3], vertex[4])
79+
union_set(vertex[3], vertex[5])
80+
81+
# After unions, sets should be {0,1,2} and {3,4,5}
82+
assert find_set(vertex[0]) == find_set(vertex[1])
83+
assert find_set(vertex[1]) == find_set(vertex[2])
84+
assert find_set(vertex[3]) == find_set(vertex[4])
85+
assert find_set(vertex[4]) == find_set(vertex[5])
86+
87+
assert find_set(vertex[0]) != find_set(vertex[3])
88+
89+
90+
if __name__ == "__main__":
91+
import doctest
92+
93+
doctest.testmod()
94+
test_disjoint_set()

0 commit comments

Comments
 (0)