Giao diện
Disjoint Set Union (DSU) - The Connected Kingdom
"Union-Find là thuật toán đơn giản nhất mà mạnh nhất. Master nó, và bạn sẽ thấy nó ở khắp nơi." - HPN
Problem Statement
Quản lý một tập hợp các phần tử chia thành nhiều nhóm không giao nhau (disjoint sets).
Hai thao tác cơ bản:
- Find(x): Tìm nhóm (representative) của phần tử x
- Union(x, y): Hợp nhất nhóm chứa x và nhóm chứa y
Initial: {0}, {1}, {2}, {3}, {4}
Union(0, 1): {0, 1}, {2}, {3}, {4}
Union(2, 3): {0, 1}, {2, 3}, {4}
Union(0, 3): {0, 1, 2, 3}, {4}
Find(2) == Find(1)? → YES (cùng nhóm)
Find(4) == Find(0)? → NO (khác nhóm)
Real-World Applications
| Domain | Use Case | Chi tiết |
|---|---|---|
| Social Networks | Friends of Friends | Tìm connected components |
| Image Processing | Pixel clustering | Segmentation |
| Networking | Network connectivity | Kiểm tra 2 nodes có thể communicate không |
| Kruskal's MST | Minimum Spanning Tree | Detect cycles khi thêm edges |
| Game Development | Dynamic connectivity | Terrain/region management |
| Compiler | Type inference | Unifying type variables |
Visualization
Naive Implementation (Slow)
python
class NaiveDSU:
    """Didactic disjoint-set forest without any balancing heuristic.

    SLOW: ``find`` walks the whole parent chain, so one operation can cost
    O(N) when the tree degenerates into a chain. Meant only for understanding
    the concept, NOT for production use.
    """

    def __init__(self, n: int) -> None:
        # Every element starts as the root of its own singleton set.
        self.parent = list(range(n))

    def find(self, x: int) -> int:
        """Return the root of x's set by following parent pointers."""
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> None:
        """Merge the sets containing x and y by linking their roots."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            # No rank/size heuristic: x's root is always hung under y's root.
            self.parent[root_x] = root_y


# Problem: the tree can become very long (a skewed tree) → O(N) for every find!
Worst case:
0 → 1 → 2 → 3 → 4 → 5 → ... → N-1
Two Key Optimizations
1. Path Compression
Idea: Khi find, flatten tree bằng cách nối trực tiếp tất cả nodes về root.
Before find(0):        After find(0):
      4                       4
      ↑                   ↗ ↗ ↖ ↖
      3                  0  1  2  3
      ↑
      2
      ↑
      1
      ↑
      0
One traversal → Tree becomes flat!
python
def find(self, x: int) -> int:
    """Return the representative of x's set, compressing the path on the way.

    Every node visited between ``x`` and the root is re-pointed directly at
    the root, so later lookups on those nodes take a single step. The walk is
    iterative (two passes) rather than recursive so that a long parent chain
    cannot exceed Python's recursion limit.
    """
    # Pass 1: locate the root of the tree containing x.
    root = x
    while self.parent[root] != root:
        root = self.parent[root]

    # Pass 2: re-point every node on the path straight at the root.
    while self.parent[x] != root:
        next_node = self.parent[x]
        self.parent[x] = root
        x = next_node
    return root


# 2. Union by Rank/Size
Idea: Khi union, luôn nối cây nhỏ hơn vào cây lớn hơn.
- Union by Rank: Dựa trên chiều cao cây
- Union by Size: Dựa trên số nodes
python
def union(self, x: int, y: int) -> bool:
    """Merge the sets containing x and y using union by rank.

    Returns:
        True if two different sets were merged, False if x and y were
        already in the same set.
    """
    root_x = self.find(x)
    root_y = self.find(y)
    if root_x == root_y:
        return False  # Already in the same set

    # Hang the lower-rank root under the higher-rank root.
    if self.rank[root_x] < self.rank[root_y]:
        self.parent[root_x] = root_y
    elif self.rank[root_x] > self.rank[root_y]:
        self.parent[root_y] = root_x
    else:
        # Equal ranks: root_x becomes the new root and its rank grows by one.
        self.parent[root_y] = root_x
        self.rank[root_x] += 1
    return True


# Optimized Implementation
python
from typing import List, Tuple, Optional
class DSU:
    """Disjoint Set Union with path compression and union by rank.

    Time complexity per operation: O(α(N)) amortized, where α is the inverse
    Ackermann function (it grows EXTREMELY slowly). For practical purposes
    this is O(1) per operation.
    """

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))  # parent[i] == i marks a root
        self.rank = [0] * n           # rank of each root, used to pick the new root
        self.size = [1] * n           # set size; only maintained at roots
        self.num_sets = n             # current number of disjoint sets

    def find(self, x: int) -> int:
        """Return the representative of x's set, with path compression.

        Implemented iteratively so that a long parent chain cannot exceed
        Python's recursion limit.
        """
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Re-point every node on the walked path directly at the root.
        while self.parent[x] != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, x: int, y: int) -> bool:
        """Merge the sets containing x and y using union by rank.

        Returns:
            True if a merge happened (different sets), False if x and y were
            already in the same set.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False

        # Make root_x the root with the higher (or equal) rank.
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x

        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        self.num_sets -= 1
        return True

    def connected(self, x: int, y: int) -> bool:
        """Return True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def get_size(self, x: int) -> int:
        """Return the size of the set containing x."""
        return self.size[self.find(x)]

    def get_num_sets(self) -> int:
        """Return the current number of disjoint sets."""
        return self.num_sets


# Production Code
python
# HPN Engineering Standard
# Implementation: Disjoint Set Union - Full Featured
from typing import List, Tuple, Dict, Set, Optional
from collections import defaultdict
class DSU:
    """Disjoint Set Union with path compression and union by rank.

    Optimizations:
        - Path compression
        - Union by rank

    Features:
        - Set size tracking
        - Set count tracking
        - Get all sets
    """

    def __init__(self, n: int) -> None:
        self.n = n
        self.parent = list(range(n))  # parent[i] == i marks a root
        self.rank = [0] * n           # rank of each root, used to pick the new root
        self.size = [1] * n           # set size; only maintained at roots
        self.num_sets = n             # current number of disjoint sets

    def find(self, x: int) -> int:
        """Return the representative of x's set, with path compression.

        Implemented iteratively so that a long parent chain cannot exceed
        Python's recursion limit.
        """
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Re-point every node on the walked path directly at the root.
        while self.parent[x] != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, x: int, y: int) -> bool:
        """Merge the sets of x and y by rank. Returns True if merged."""
        px, py = self.find(x), self.find(y)
        if px == py:
            return False
        # Make px the root with the higher (or equal) rank.
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        self.size[px] += self.size[py]
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        self.num_sets -= 1
        return True

    def connected(self, x: int, y: int) -> bool:
        """Return True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def get_size(self, x: int) -> int:
        """Return the size of the set containing x."""
        return self.size[self.find(x)]

    def get_num_sets(self) -> int:
        """Return the current number of disjoint sets."""
        return self.num_sets

    def get_all_sets(self) -> Dict[int, List[int]]:
        """Return all sets as {root: [members]}."""
        sets = defaultdict(list)
        for i in range(self.n):
            sets[self.find(i)].append(i)
        return dict(sets)
# ============================================
# APPLICATION: KRUSKAL'S MST
# ============================================
def kruskal_mst(
    n: int,
    edges: List[Tuple[int, int, float]]
) -> Tuple[float, List[Tuple[int, int, float]]]:
    """Build a minimum spanning tree with Kruskal's algorithm on top of DSU.

    Args:
        n: Number of vertices.
        edges: Edges given as (u, v, weight) triples.

    Returns:
        A pair (total_weight, mst_edges). If the edges do not connect all
        vertices, the loop simply runs out of edges and the result holds
        fewer than n - 1 edges.

    Time: O(E log E) for sorting + O(E α(V)) for the DSU operations.
    """
    components = DSU(n)
    chosen = []
    cost = 0
    edges_needed = n - 1  # a spanning tree has exactly V-1 edges

    # Visit edges from lightest to heaviest.
    for a, b, w in sorted(edges, key=lambda item: item[2]):
        if not components.union(a, b):
            continue  # both endpoints already connected: edge would close a cycle
        chosen.append((a, b, w))
        cost += w
        if len(chosen) == edges_needed:
            break

    return cost, chosen
# ============================================
# APPLICATION: CONNECTED COMPONENTS
# ============================================
def find_connected_components(
    n: int,
    edges: List[Tuple[int, int]]
) -> List[List[int]]:
    """Return the connected components of an undirected graph.

    Each component is a list of vertex ids; vertices with no edges form
    singleton components.

    Time: O(V + E × α(V)) ≈ O(V + E)
    """
    forest = DSU(n)
    for a, b in edges:
        forest.union(a, b)
    members_by_root = forest.get_all_sets()
    return [members for members in members_by_root.values()]
def is_graph_connected(n: int, edges: List[Tuple[int, int]]) -> bool:
    """Return True if the undirected graph on n vertices is connected.

    Graphs with at most one vertex are treated as connected.
    """
    if n <= 1:
        return True
    dsu = DSU(n)
    for u, v in edges:
        dsu.union(u, v)
    # Connected means every vertex ended up in one set. Read the num_sets
    # counter that DSU maintains on every successful union.
    return dsu.num_sets == 1
# ============================================
# APPLICATION: SOCIAL NETWORK (Friends of Friends)
# ============================================
def find_friend_groups(
    users: List[str],
    friendships: List[Tuple[str, str]]
) -> List[List[str]]:
    """Partition users into friend groups.

    Two users land in the same group when they are friends, friends of
    friends, and so on. Users without any friendship form their own group.
    """
    # Map each user name to its position in `users`, the id used inside the DSU.
    index_of = {name: idx for idx, name in enumerate(users)}
    network = DSU(len(users))
    for a, b in friendships:
        network.union(index_of[a], index_of[b])

    result = []
    for member_ids in network.get_all_sets().values():
        result.append([users[idx] for idx in member_ids])
    return result
def count_mutual_friends(
    users: List[str],
    friendships: List[Tuple[str, str]],
    user1: str,
    user2: str
) -> int:
    """Report whether two users belong to the same friend network.

    Returns:
        The size of their shared group if user1 and user2 are connected,
        otherwise 0.
    """
    # Map each user name to its position in `users`, the id used inside the DSU.
    index_of = {name: idx for idx, name in enumerate(users)}
    network = DSU(len(users))
    for a, b in friendships:
        network.union(index_of[a], index_of[b])

    first = index_of[user1]
    second = index_of[user2]
    return network.get_size(first) if network.connected(first, second) else 0
# ============================================
# APPLICATION: DYNAMIC CONNECTIVITY
# ============================================
class DynamicConnectivity:
    """Track connectivity as edges are added dynamically.

    Note: a standard DSU only supports ADDING edges, not deleting them.
    Delete support needs more complex data structures.
    """

    def __init__(self, n: int) -> None:
        self.dsu = DSU(n)
        self.edge_count = 0  # counts every add_edge call, including redundant edges

    def add_edge(self, u: int, v: int) -> bool:
        """Add an edge. Returns True if it connects two separate components."""
        self.edge_count += 1
        return self.dsu.union(u, v)

    def is_connected(self, u: int, v: int) -> bool:
        """Return True if u and v are in the same component."""
        return self.dsu.connected(u, v)

    def get_component_count(self) -> int:
        """Return the current number of components."""
        # Read the counter that DSU maintains on every successful union.
        return self.dsu.num_sets

    def is_fully_connected(self) -> bool:
        """Return True if all vertices form a single component."""
        return self.dsu.num_sets == 1
# ============================================
# USAGE EXAMPLE
# ============================================
if __name__ == "__main__":
    # Example 1: Basic DSU
    print("=== Basic DSU ===")
    dsu = DSU(5)
    dsu.union(0, 1)
    dsu.union(2, 3)
    dsu.union(1, 3)
    print(f"0 and 2 connected? {dsu.connected(0, 2)}")  # True
    print(f"0 and 4 connected? {dsu.connected(0, 4)}")  # False
    # Read the num_sets counter that DSU maintains on every successful union.
    print(f"Number of sets: {dsu.num_sets}")  # 2
    print(f"All sets: {dsu.get_all_sets()}")

    # Example 2: Kruskal's MST
    print("\n=== Kruskal's MST ===")
    edges = [
        (0, 1, 4), (0, 7, 8), (1, 2, 8), (1, 7, 11),
        (2, 3, 7), (2, 5, 4), (2, 8, 2), (3, 4, 9),
        (3, 5, 14), (4, 5, 10), (5, 6, 2), (6, 7, 1),
        (6, 8, 6), (7, 8, 7)
    ]
    total, mst = kruskal_mst(9, edges)
    print(f"MST weight: {total}")  # 37
    print(f"MST edges: {mst}")

    # Example 3: Social Network
    print("\n=== Friend Groups ===")
    users = ["Alice", "Bob", "Charlie", "David", "Eve", "Frank"]
    friendships = [
        ("Alice", "Bob"),
        ("Bob", "Charlie"),
        ("David", "Eve"),
    ]
    groups = find_friend_groups(users, friendships)
    for i, group in enumerate(groups):
        print(f" Group {i+1}: {group}")

    # Example 4: Connected Components
    print("\n=== Graph Connectivity ===")
    n = 6
    graph_edges = [(0, 1), (1, 2), (3, 4)]
    print(f"Is connected? {is_graph_connected(n, graph_edges)}")
    print(f"Components: {find_connected_components(n, graph_edges)}")


# Complexity Analysis
| Operation | Naive | With Both Optimizations |
|---|---|---|
| Find | O(N) | O(α(N)) amortized |
| Union | O(N) | O(α(N)) amortized |
| M operations | O(M × N) | O(M × α(N)) |
📘 Inverse Ackermann Function α(N)
α(N) tăng cực kỳ chậm: α(N) ≤ 4 với mọi giá trị N có thể gặp trong thực tế.
Trong thực tế: Coi như O(1)!
When to Use DSU
| Problem Type | Use DSU? | Alternative |
|---|---|---|
| Dynamic connectivity | ✅ Yes | - |
| Static connected components | ✅ Yes | BFS/DFS cũng được |
| Kruskal's MST | ✅ Yes | - |
| Cycle detection (undirected) | ✅ Yes | DFS |
| Need to DELETE edges | ❌ No | Link-Cut Tree |
| Directed graph connectivity | ❌ No | Tarjan's SCC |
Common Patterns
python
# Pattern 1: Check if adding edge creates cycle
# union() returns False when u and v are already in the same set, so a single
# call both detects the cycle and performs the merge otherwise.
if not dsu.union(u, v):
    print("Adding this edge creates a cycle!")

# Pattern 2: Count components after all operations
dsu = DSU(n)
for u, v in edges:
    dsu.union(u, v)
# num_sets is the counter DSU decrements on every successful union.
print(f"Components: {dsu.num_sets}")

# Pattern 3: Group elements by component
groups = dsu.get_all_sets()
for root, members in groups.items():
    print(f"Set with root {root}: {members}")


# 💡 HPN's Rule
"Thấy 'connected', 'same group', 'merge' → DSU. Luôn dùng Path Compression + Union by Rank."