You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
tech-interview-handbook/experimental/utilities/python/union_find.py

51 lines
1.3 KiB

# Union-Find (disjoint-set) data structure
# Reference: https://en.wikipedia.org/wiki/Disjoint-set_data_structure

# Seven elements (0..6), each starting out in its own singleton set.
# parents[i] is the parent of i; an element that is its own parent is a root.
parents = list(range(7))
# weights[i] is the number of elements in the tree rooted at i. union()
# only ever updates the weight of the surviving root, so a value is
# meaningful only while i is still a root.
weights = [1] * 7
def find_root(parents, p):
    """Return the root (representative) of the set that contains ``p``.

    First follow parent links from ``p`` up to the root, i.e. the element
    that is its own parent. Then walk the same path a second time and
    re-point every visited node directly at that root (path compression).
    ``parents`` is mutated in place by the second pass.

    Args:
        parents: list where ``parents[i]`` is the parent of element ``i``.
        p: element whose root is wanted.

    Returns:
        The root element of ``p``'s set.

    Complexity: O(log n) when the trees are built by union-by-size.
    Path compression lowers the amortized cost further.
    """
    top = p
    while parents[top] != top:
        top = parents[top]

    # Second pass: flatten the path so later lookups reach the root in one hop.
    node = p
    while parents[node] != node:
        nxt = parents[node]
        parents[node] = top
        node = nxt
    return top
def union(parents, p, q, sizes=None):
    """Merge the sets containing ``p`` and ``q`` using union by size.

    Both elements are first resolved to their roots with ``find_root``,
    which also compresses their paths. The root of the smaller tree is then
    linked beneath the root of the larger tree. On a tie, ``p``'s root goes
    under ``q``'s root. The surviving root's size absorbs the other's.
    If ``p`` and ``q`` are already in the same set, nothing is changed.

    Args:
        parents: parent list as used by ``find_root``; mutated in place.
        p: first element to merge.
        q: second element to merge.
        sizes: per-root size table, mutated in place. Defaults to the
            module-level ``weights`` list. It must be indexed the same way
            as ``parents``.

    Complexity: O(log n), dominated by the two ``find_root`` calls.
    """
    if sizes is None:
        sizes = weights
    root_p = find_root(parents, p)
    root_q = find_root(parents, q)
    if root_p == root_q:
        # Already in the same set. Falling through would add the set's size
        # to itself and corrupt every later union-by-size decision.
        return
    # Attach the smaller tree beneath the larger one to keep trees shallow.
    if sizes[root_p] > sizes[root_q]:
        parents[root_q] = root_p
        sizes[root_p] += sizes[root_q]
    else:
        parents[root_p] = root_q
        sizes[root_q] += sizes[root_p]
# Demo. Initially every element is alone: [0], [1], [2], [3], [4], [5], [6]
print(find_root(parents, 2) == 2)  # prints True: 2 is still its own root

# Build two groups of three: [0], [1, 2, 3], [4, 5, 6]
union(parents, 1, 2)
union(parents, 2, 3)
union(parents, 4, 5)
union(parents, 4, 6)

# 0 is untouched, and each group of three now shares a single root.
print(find_root(parents, 0))                        # prints 0
print([find_root(parents, i) for i in (1, 2, 3)])   # prints [2, 2, 2]
print([find_root(parents, i) for i in (4, 5, 6)])   # prints [5, 5, 5]

# Join the two groups through 2 and 4: [0], [1, 2, 3, 4, 5, 6]
union(parents, 2, 4)
print([find_root(parents, i) for i in range(1, 7)])  # prints [5, 5, 5, 5, 5, 5]