
Disjoint Set Union (DSU) - The Connected Kingdom

"Union-Find is the simplest yet most powerful algorithm. Master it, and you'll see it everywhere." - HPN

Problem Statement

Manage a collection of elements partitioned into several non-overlapping groups (disjoint sets).

Two basic operations:

  1. Find(x): return the group representative (root) of element x
  2. Union(x, y): merge the group containing x with the group containing y

Initial: {0}, {1}, {2}, {3}, {4}

Union(0, 1): {0, 1}, {2}, {3}, {4}
Union(2, 3): {0, 1}, {2, 3}, {4}
Union(0, 3): {0, 1, 2, 3}, {4}

Find(2) == Find(1)?  → YES (same group)
Find(4) == Find(0)?  → NO (different groups)
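The trace above can be replayed with a bare parent array. This is a minimal sketch: the module-level names `parent`, `find` and `union` are chosen only for this illustration, and the root it ends up picking as representative may differ from other implementations.

```python
# Minimal parent-array DSU, used only to replay the trace above.
parent = list(range(5))  # Initial: {0}, {1}, {2}, {3}, {4}

def find(x: int) -> int:
    # Follow parent pointers up to the root (the representative).
    while parent[x] != x:
        x = parent[x]
    return x

def union(x: int, y: int) -> None:
    # Attach the root of x's tree under the root of y's tree.
    root_x, root_y = find(x), find(y)
    if root_x != root_y:
        parent[root_x] = root_y

union(0, 1)  # {0, 1}, {2}, {3}, {4}
union(2, 3)  # {0, 1}, {2, 3}, {4}
union(0, 3)  # {0, 1, 2, 3}, {4}

print(find(2) == find(1))  # True  (same group)
print(find(4) == find(0))  # False (different groups)
```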

Real-World Applications

| Domain | Use Case | Details |
|---|---|---|
| Social Networks | Friends of Friends | Find connected components |
| Image Processing | Pixel clustering | Segmentation |
| Networking | Network connectivity | Check whether 2 nodes can communicate |
| Kruskal's MST | Minimum Spanning Tree | Detect cycles when adding edges |
| Game Development | Dynamic connectivity | Terrain/region management |
| Compiler | Type inference | Unifying type variables |
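As a concrete instance of the Image Processing row, here is a minimal sketch that counts 4-connected regions of foreground pixels in a binary image. The function name `count_regions`, the 0/1 grid encoding and the choice of 4-connectivity are assumptions for illustration; `find` uses path halving, a one-pass variant of path compression.

```python
from typing import List

def count_regions(grid: List[List[int]]) -> int:
    """Count 4-connected regions of 1-pixels with a DSU over pixel indices."""
    if not grid:
        return 0
    rows, cols = len(grid), len(grid[0])
    parent = list(range(rows * cols))  # pixel (r, c) has index r * cols + c

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    # Every foreground pixel starts as its own region.
    regions = sum(cell == 1 for row in grid for cell in row)
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] != 1:
                continue
            # Only look down and right so each neighbouring pair is checked once.
            for nr, nc in ((r + 1, c), (r, c + 1)):
                if nr < rows and nc < cols and grid[nr][nc] == 1:
                    a, b = find(r * cols + c), find(nr * cols + nc)
                    if a != b:
                        parent[a] = b
                        regions -= 1  # two regions merged into one
    return regions

image = [
    [1, 1, 0, 0],
    [0, 1, 0, 1],
    [1, 0, 0, 1],
]
print(count_regions(image))  # 3
```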


Naive Implementation (Slow)

python
class NaiveDSU:
    """
    ❌ SLOW - O(N) per operation in the worst case.
    For understanding the concept only; do NOT use in production!
    """
    
    def __init__(self, n: int):
        self.parent = list(range(n))
    
    def find(self, x: int) -> int:
        """Find root by following parent pointers."""
        while self.parent[x] != x:
            x = self.parent[x]
        return x
    
    def union(self, x: int, y: int):
        """Union by connecting roots."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y

Problem: the tree can degenerate into a long chain (skewed tree) → O(N) per find!

Worst case:
0 → 1 → 2 → 3 → 4 → 5 → ... → N-1
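To make the worst case concrete, the sketch below rebuilds the naive structure (copied from above) and performs union(0, 1), union(1, 2), ..., union(N-2, N-1). The helper `depth` is hypothetical and simply counts the parent hops that find(x) would make.

```python
class NaiveDSU:
    """Same naive structure as above: no path compression, no union by rank."""

    def __init__(self, n: int):
        self.parent = list(range(n))

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> None:
        root_x, root_y = self.find(x), self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y

def depth(dsu: NaiveDSU, x: int) -> int:
    """Number of parent hops from x to its root (the work find(x) does)."""
    hops = 0
    while dsu.parent[x] != x:
        x = dsu.parent[x]
        hops += 1
    return hops

n = 1000
dsu = NaiveDSU(n)
for i in range(n - 1):
    # At this point i is still a root, so it gets hung under i + 1.
    dsu.union(i, i + 1)

print(depth(dsu, 0))  # 999: find(0) walks the whole chain
```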

Two Key Optimizations

1. Path Compression

Idea: during find, flatten the tree by linking every node on the search path directly to the root.

Before find(0):   0 → 1 → 2 → 3 → 4   (4 is the root)
After find(0):    0 → 4, 1 → 4, 2 → 4, 3 → 4

One traversal → every node on the path now points directly at the root, so the tree becomes flat!
python
def find(self, x: int) -> int:
    """Find with path compression."""
    if self.parent[x] != x:
        self.parent[x] = self.find(self.parent[x])  # Recursive compression
    return self.parent[x]
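The recursive `find` above recurses once per level, so on a very deep tree (for example the skewed chain built by the naive union) it can exceed CPython's default recursion limit of 1000 (see `sys.getrecursionlimit()`). With union by rank the depth stays logarithmic, so this mostly matters when path compression is used on its own. A minimal two-pass iterative sketch, assuming a plain `parent` list and the hypothetical name `find_iterative`:

```python
from typing import List

def find_iterative(parent: List[int], x: int) -> int:
    """Find with full path compression, without recursion."""
    # Pass 1: walk up to the root.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Pass 2: re-point every node on the path directly at the root.
    while x != root:
        next_x = parent[x]
        parent[x] = root
        x = next_x
    return root

p = [1, 2, 3, 4, 4]  # chain 0 → 1 → 2 → 3 → 4, root 4
print(find_iterative(p, 0), p)  # 4 [4, 4, 4, 4, 4]
```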

2. Union by Rank/Size

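Idea: when uniting, always attach the smaller tree under the larger one.

  • Union by Rank: based on tree height (rank is an upper bound on height, since path compression can shorten trees without updating rank)
  • Union by Size: based on the number of nodes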
python
def union(self, x: int, y: int) -> bool:
    """Union by rank. Returns True if union happened."""
    root_x = self.find(x)
    root_y = self.find(y)
    
    if root_x == root_y:
        return False  # Already in same set
    
    # Attach the lower-rank tree under the higher-rank tree
    if self.rank[root_x] < self.rank[root_y]:
        self.parent[root_x] = root_y
    elif self.rank[root_x] > self.rank[root_y]:
        self.parent[root_y] = root_x
    else:
        self.parent[root_y] = root_x
        self.rank[root_x] += 1
    
    return True
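The bullets above also mention union by size. A minimal sketch of that variant, assuming a hypothetical `DSUBySize` class: it compares node counts instead of ranks, so the `size` array doubles as the set-size tracker. Combined with path compression, it keeps the same O(α(N)) amortized bound as union by rank.

```python
class DSUBySize:
    """Union by size + path compression (hypothetical class name)."""

    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n  # size[r] is only meaningful while r is a root

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x: int, y: int) -> bool:
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        # Attach the tree with fewer nodes under the one with more nodes.
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        return True

dsu = DSUBySize(5)
dsu.union(0, 1)
dsu.union(2, 3)
dsu.union(0, 3)
print(dsu.size[dsu.find(2)])  # 4
```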

Optimized Implementation

python
from typing import List, Tuple, Optional


class DSU:
    """
    Disjoint Set Union with Path Compression + Union by Rank.
    
    Time Complexity per operation: O(α(N)) ≈ O(1) amortized
    where α is inverse Ackermann function (grows EXTREMELY slowly)
    
    For practical purposes: O(1) per operation!
    """
    
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.size = [1] * n  # Track size of each set
        self.num_sets = n    # Number of disjoint sets
    
    def find(self, x: int) -> int:
        """
        Find representative with path compression.
        
        Time: O(α(N)) amortized ≈ O(1)
        """
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def union(self, x: int, y: int) -> bool:
        """
        Union by rank.
        
        Returns: True if union happened (different sets), False if same set.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        
        if root_x == root_y:
            return False
        
        # Union by rank
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        
        self.num_sets -= 1
        return True
    
    def connected(self, x: int, y: int) -> bool:
        """Check if x and y are in the same set."""
        return self.find(x) == self.find(y)
    
    def get_size(self, x: int) -> int:
        """Get size of set containing x."""
        return self.size[self.find(x)]
    
    def get_num_sets(self) -> int:
        """Get number of disjoint sets."""
        return self.num_sets
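Why union by rank matters even before path compression: a root of rank r always has at least 2^r nodes, so with union by rank alone the tree height never exceeds floor(log2 N). The sketch below checks this empirically by merging random pairs until one set remains; `tallest_tree_union_by_rank` is a hypothetical helper, and a random merge order is only one test scenario, not a proof.

```python
import math
import random

def tallest_tree_union_by_rank(n: int, seed: int = 0) -> int:
    """Merge random pairs with union by rank only (no path compression)
    until one set remains; return the height of the final tree."""
    rng = random.Random(seed)
    parent = list(range(n))
    rank = [0] * n

    def find(x: int) -> int:
        while parent[x] != x:
            x = parent[x]
        return x

    sets = n
    while sets > 1:
        rx, ry = find(rng.randrange(n)), find(rng.randrange(n))
        if rx == ry:
            continue
        if rank[rx] < rank[ry]:
            rx, ry = ry, rx
        parent[ry] = rx  # lower-rank root goes under higher-rank root
        if rank[rx] == rank[ry]:
            rank[rx] += 1
        sets -= 1

    def depth(x: int) -> int:
        hops = 0
        while parent[x] != x:
            x = parent[x]
            hops += 1
        return hops

    return max(depth(x) for x in range(n))

n = 1024
print(tallest_tree_union_by_rank(n), "<=", int(math.log2(n)))
```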

Production Code

python
# HPN Engineering Standard
# Implementation: Disjoint Set Union - Full Featured

from typing import List, Tuple, Dict, Set, Optional
from collections import defaultdict


class DSU:
    """
    Production-ready DSU with all features.
    
    Optimizations:
    - Path compression
    - Union by rank
    
    Features:
    - Get all sets
    - Set size tracking
    - Set count tracking (num_sets)
    """
    
    def __init__(self, n: int):
        self.n = n
        self.parent = list(range(n))
        self.rank = [0] * n
        self.size = [1] * n
        self.num_sets = n
    
    def find(self, x: int) -> int:
        """Find with path compression."""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]
    
    def union(self, x: int, y: int) -> bool:
        """Union by rank. Returns True if merged."""
        px, py = self.find(x), self.find(y)
        if px == py:
            return False
        
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        
        self.parent[py] = px
        self.size[px] += self.size[py]
        
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        
        self.num_sets -= 1
        return True
    
    def connected(self, x: int, y: int) -> bool:
        return self.find(x) == self.find(y)
    
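    def get_size(self, x: int) -> int:
        return self.size[self.find(x)]
    
    def get_num_sets(self) -> int:
        """Number of disjoint sets (used by is_graph_connected and DynamicConnectivity)."""
        return self.num_sets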
    
    def get_all_sets(self) -> Dict[int, List[int]]:
        """Get all sets as {root: [members]}."""
        sets = defaultdict(list)
        for i in range(self.n):
            sets[self.find(i)].append(i)
        return dict(sets)


# ============================================
# APPLICATION: KRUSKAL'S MST
# ============================================

def kruskal_mst(
    n: int, 
    edges: List[Tuple[int, int, float]]
) -> Tuple[float, List[Tuple[int, int, float]]]:
    """
    Kruskal's Minimum Spanning Tree using DSU.
    
    Args:
        n: Number of vertices
        edges: [(u, v, weight), ...]
    
    Returns:
        (total_weight, mst_edges)
    
    Time: O(E log E) for sorting + O(E α(V)) for DSU
    """
    # Sort edges by weight
    sorted_edges = sorted(edges, key=lambda e: e[2])
    
    dsu = DSU(n)
    mst_edges = []
    total_weight = 0
    
    for u, v, weight in sorted_edges:
        if dsu.union(u, v):
            mst_edges.append((u, v, weight))
            total_weight += weight
            
            # MST has exactly V-1 edges
            if len(mst_edges) == n - 1:
                break
    
    return total_weight, mst_edges


# ============================================
# APPLICATION: CONNECTED COMPONENTS
# ============================================

def find_connected_components(
    n: int, 
    edges: List[Tuple[int, int]]
) -> List[List[int]]:
    """
    Find all connected components in an undirected graph.
    
    Time: O(V + E × α(V)) ≈ O(V + E)
    """
    dsu = DSU(n)
    
    for u, v in edges:
        dsu.union(u, v)
    
    return list(dsu.get_all_sets().values())


def is_graph_connected(n: int, edges: List[Tuple[int, int]]) -> bool:
    """Check if graph is connected."""
    if n <= 1:
        return True
    
    dsu = DSU(n)
    for u, v in edges:
        dsu.union(u, v)
    
    return dsu.get_num_sets() == 1


# ============================================
# APPLICATION: SOCIAL NETWORK (Friends of Friends)
# ============================================

def find_friend_groups(
    users: List[str],
    friendships: List[Tuple[str, str]]
) -> List[List[str]]:
    """
    Find friend groups in social network.
    
    Two users are in same group if they're friends,
    or friends of friends, etc.
    """
    user_to_id = {user: i for i, user in enumerate(users)}
    n = len(users)
    
    dsu = DSU(n)
    
    for u, v in friendships:
        dsu.union(user_to_id[u], user_to_id[v])
    
    groups = dsu.get_all_sets()
    return [[users[i] for i in members] for members in groups.values()]


def count_mutual_friends(
    users: List[str],
    friendships: List[Tuple[str, str]],
    user1: str,
    user2: str
) -> int:
    """
    Despite the name, this does not count mutual friends: it checks whether
    two users are in the same friend network and returns that network's size
    if they are, 0 if they are not.
    """
    user_to_id = {user: i for i, user in enumerate(users)}
    dsu = DSU(len(users))
    
    for u, v in friendships:
        dsu.union(user_to_id[u], user_to_id[v])
    
    id1, id2 = user_to_id[user1], user_to_id[user2]
    
    if dsu.connected(id1, id2):
        return dsu.get_size(id1)
    return 0


# ============================================
# APPLICATION: DYNAMIC CONNECTIVITY
# ============================================

class DynamicConnectivity:
    """
    Track connectivity as edges are added dynamically.
    
    ⚠️ Note: Standard DSU only supports ADD, not DELETE.
    For delete support, need more complex data structures.
    """
    
    def __init__(self, n: int):
        self.dsu = DSU(n)
        self.edge_count = 0
    
    def add_edge(self, u: int, v: int) -> bool:
        """
        Add edge. Returns True if it connects two components.
        """
        self.edge_count += 1
        return self.dsu.union(u, v)
    
    def is_connected(self, u: int, v: int) -> bool:
        return self.dsu.connected(u, v)
    
    def get_component_count(self) -> int:
        return self.dsu.get_num_sets()
    
    def is_fully_connected(self) -> bool:
        return self.dsu.get_num_sets() == 1


# ============================================
# USAGE EXAMPLE
# ============================================

if __name__ == "__main__":
    # Example 1: Basic DSU
    print("=== Basic DSU ===")
    dsu = DSU(5)
    dsu.union(0, 1)
    dsu.union(2, 3)
    dsu.union(1, 3)
    
    print(f"0 and 2 connected? {dsu.connected(0, 2)}")  # True
    print(f"0 and 4 connected? {dsu.connected(0, 4)}")  # False
    print(f"Number of sets: {dsu.get_num_sets()}")      # 2
    print(f"All sets: {dsu.get_all_sets()}")
    
    # Example 2: Kruskal's MST
    print("\n=== Kruskal's MST ===")
    edges = [
        (0, 1, 4), (0, 7, 8), (1, 2, 8), (1, 7, 11),
        (2, 3, 7), (2, 5, 4), (2, 8, 2), (3, 4, 9),
        (3, 5, 14), (4, 5, 10), (5, 6, 2), (6, 7, 1),
        (6, 8, 6), (7, 8, 7)
    ]
    total, mst = kruskal_mst(9, edges)
    print(f"MST weight: {total}")
    print(f"MST edges: {mst}")
    
    # Example 3: Social Network
    print("\n=== Friend Groups ===")
    users = ["Alice", "Bob", "Charlie", "David", "Eve", "Frank"]
    friendships = [
        ("Alice", "Bob"),
        ("Bob", "Charlie"),
        ("David", "Eve"),
    ]
    
    groups = find_friend_groups(users, friendships)
    for i, group in enumerate(groups):
        print(f"  Group {i+1}: {group}")
    
    # Example 4: Connected Components
    print("\n=== Graph Connectivity ===")
    n = 6
    graph_edges = [(0, 1), (1, 2), (3, 4)]
    print(f"Is connected? {is_graph_connected(n, graph_edges)}")
    print(f"Components: {find_connected_components(n, graph_edges)}")
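The production `DSU` above has no undo. If you need to undo unions (for example in backtracking searches), one common approach is a DSU with rollback, sketched below under these assumptions: the class name `RollbackDSU` and its `undo()` method are hypothetical, undo is strictly LIFO (it reverts the latest `union()` call), and path compression is dropped, trading O(α(N)) finds for O(log N) finds so that every union changes only one parent pointer.

```python
from typing import List, Optional, Tuple

class RollbackDSU:
    """Union by size, no path compression, with LIFO undo (hypothetical)."""

    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n
        self.num_sets = n
        # One entry per union() call: (attached_root, new_root), or None for a no-op.
        self.history: List[Optional[Tuple[int, int]]] = []

    def find(self, x: int) -> int:
        # No path compression: parent pointers only change in union()/undo().
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            self.history.append(None)  # keep undo() in step with union()
            return False
        if self.size[rx] < self.size[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        self.size[rx] += self.size[ry]
        self.num_sets -= 1
        self.history.append((ry, rx))
        return True

    def undo(self) -> None:
        """Revert the most recent union() call."""
        entry = self.history.pop()
        if entry is None:
            return
        child, root = entry
        self.parent[child] = child
        self.size[root] -= self.size[child]  # child's size was frozen while attached
        self.num_sets += 1

rb = RollbackDSU(4)
rb.union(0, 1)
rb.union(2, 3)
print(rb.find(0) == rb.find(1), rb.num_sets)  # True 2
rb.undo()  # reverts union(2, 3)
print(rb.find(2) == rb.find(3), rb.num_sets)  # False 3
```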

Complexity Analysis

| Operation | Naive | With Both Optimizations |
|---|---|---|
| Find | O(N) | O(α(N)) ≈ O(1) |
| Union | O(N) | O(α(N)) ≈ O(1) |
| M operations | O(M×N) | O(M×α(N)) |

📘 Inverse Ackermann Function α(N)

α(N) < 5 for every practically occurring N (even N as large as the number of atoms in the universe).

In practice: treat it as O(1)!
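To see how slowly α grows, the sketch below uses the A_k(j) variant defined in CLRS (Introduction to Algorithms): A_0(j) = j + 1, A_k(j) applies A_{k-1} to j a total of j + 1 times, and α(n) is the smallest k with A_k(1) ≥ n. The function names are hypothetical. A_4(1) is far too large to compute (it exceeds 2^2048), so the sketch stops at k = 3 and returns 4 for anything larger, which is exact for every n up to A_4(1).

```python
def ackermann_level(k: int, j: int) -> int:
    """A_k(j): A_0(j) = j + 1; A_k(j) = A_{k-1} applied (j + 1) times to j."""
    if k == 0:
        return j + 1
    result = j
    for _ in range(j + 1):
        result = ackermann_level(k - 1, result)
    return result

def inverse_ackermann(n: int) -> int:
    """Smallest k with A_k(1) >= n, for n up to A_4(1)."""
    for k in range(4):
        if ackermann_level(k, 1) >= n:
            return k
    # A_4(1) > 2**2048, so every n we could ever store lands here.
    return 4

print([ackermann_level(k, 1) for k in range(4)])  # [2, 3, 7, 2047]
print(inverse_ackermann(10**80))  # 4
```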

When to Use DSU

| Problem Type | Use DSU? | Alternative |
|---|---|---|
| Dynamic connectivity | ✅ Yes | - |
| Static connected components | ✅ Yes | BFS/DFS also works |
| Kruskal's MST | ✅ Yes | - |
| Cycle detection (undirected) | ✅ Yes | DFS |
| Need to DELETE edges | ❌ No | Link-Cut Tree |
| Directed graph connectivity | ❌ No | Tarjan's SCC |

Common Patterns

python
# Pattern 1: Check if adding edge creates cycle
if dsu.connected(u, v):
    print("Adding this edge creates a cycle!")
else:
    dsu.union(u, v)

# Pattern 2: Count components after all operations
dsu = DSU(n)
for u, v in edges:
    dsu.union(u, v)
print(f"Components: {dsu.get_num_sets()}")

# Pattern 3: Group elements by component
groups = dsu.get_all_sets()
for root, members in groups.items():
    print(f"Set with root {root}: {members}")
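The three patterns above assume an existing `dsu`, `u`, `v` and `edges`. A self-contained sketch that applies all three in one pass over an edge list; `analyze_graph` is a hypothetical helper, and it inlines a tiny parent-array DSU (with path halving) rather than importing the class above.

```python
from typing import Dict, List, Tuple

def analyze_graph(
    n: int, edges: List[Tuple[int, int]]
) -> Tuple[List[Tuple[int, int]], int, Dict[int, List[int]]]:
    """Return (cycle-closing edges, component count, {root: members})."""
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    cycle_edges: List[Tuple[int, int]] = []
    components = n
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            cycle_edges.append((u, v))  # Pattern 1: u and v already connected
        else:
            parent[ru] = rv
            components -= 1  # Pattern 2: each successful union removes one component
    groups: Dict[int, List[int]] = {}
    for x in range(n):
        groups.setdefault(find(x), []).append(x)  # Pattern 3
    return cycle_edges, components, groups

cycles, comps, groups = analyze_graph(5, [(0, 1), (1, 2), (2, 0), (3, 4)])
print(cycles, comps)  # [(2, 0)] 2
```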

💡 HPN's Rule

"When you see 'connected', 'same group', or 'merge' → DSU. Always use Path Compression + Union by Rank."