Fenwick Tree (Binary Indexed Tree) - The Elegant Alternative

"Fenwick Tree does the same job as Segment Tree, but with 1/4 of the code and 1/2 of the memory." - HPN

Problem Statement

The same range-query problem as for the Segment Tree:

  1. Prefix Sum Query: compute the sum of A[0..i]
  2. Point Update: change A[i]
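To see why a special structure is worth it, here is a small sketch of the two obvious representations. Each makes one operation O(1) and the other O(N). The class names `PlainArray` and `PrefixArray` are illustrative only and do not come from the original.

```python
from typing import List


class PlainArray:
    """Illustrative baseline: store A directly. Update O(1), prefix query O(N)."""

    def __init__(self, arr: List[int]):
        self.a = list(arr)

    def update(self, i: int, delta: int) -> None:
        self.a[i] += delta

    def prefix_sum(self, i: int) -> int:
        return sum(self.a[: i + 1])  # scans i + 1 elements


class PrefixArray:
    """Illustrative baseline: store running sums. Prefix query O(1), update O(N)."""

    def __init__(self, arr: List[int]):
        self.p: List[int] = []
        running = 0
        for x in arr:
            running += x
            self.p.append(running)

    def update(self, i: int, delta: int) -> None:
        for j in range(i, len(self.p)):  # every prefix from i onward changes
            self.p[j] += delta

    def prefix_sum(self, i: int) -> int:
        return self.p[i]


a = PlainArray([1, 3, 5, 7])
b = PrefixArray([1, 3, 5, 7])
a.update(1, 10)
b.update(1, 10)
print(a.prefix_sum(2), b.prefix_sum(2))  # 19 19
```

A Fenwick Tree sits between these two extremes: both operations cost O(log N).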

📘 Why Fenwick Tree?

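  • Simpler: roughly 4x less code than a Segment Tree
  • Faster: smaller constant factor (two short loops, no recursion)
  • Memory: exactly N + 1 array slots instead of the ~4N of a typical array-based Segment Tree (both are O(N))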

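Drawback: range queries only work for operations that have an inverse (sum, XOR), because a range [L..R] is computed as prefix(R) "minus" prefix(L-1). The basic structure cannot answer arbitrary range min/max queries.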
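To see why the inverse matters, here is a sketch of the same structure with XOR in place of addition (the class name `XorFenwick` is hypothetical, not from the original). The range answer cancels the prefix before `left`, which only works because XOR undoes itself; min and max have no operation that removes the contribution of `A[0..left-1]`.

```python
from typing import List


class XorFenwick:
    """Hypothetical XOR variant of a Fenwick Tree (sketch, 0-indexed API)."""

    def __init__(self, n: int):
        self.n = n
        self.tree: List[int] = [0] * (n + 1)  # 1-indexed internally

    def update(self, i: int, val: int) -> None:
        """XOR `val` into element i."""
        i += 1
        while i <= self.n:
            self.tree[i] ^= val
            i += i & (-i)

    def prefix_xor(self, i: int) -> int:
        """XOR of elements [0..i]; returns 0 when i == -1."""
        i += 1
        acc = 0
        while i > 0:
            acc ^= self.tree[i]
            i -= i & (-i)
        return acc

    def range_xor(self, left: int, right: int) -> int:
        # XOR is its own inverse, so XOR-ing in prefix [0..left-1] cancels it.
        return self.prefix_xor(right) ^ self.prefix_xor(left - 1)


arr = [5, 1, 7, 2]
fw = XorFenwick(len(arr))
for idx, v in enumerate(arr):
    fw.update(idx, v)
print(fw.range_xor(1, 3))  # 1 ^ 7 ^ 2 = 4
```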

The Magic of Binary Representation

The Fenwick Tree is built on the Least Significant Bit (LSB) of the index, i.e. its lowest set bit.

Index (1-based) | Binary | LSB | Range covered
----------------|--------|-----|---------------
1               | 0001   | 1   | [1, 1]
2               | 0010   | 2   | [1, 2]
3               | 0011   | 1   | [3, 3]
4               | 0100   | 4   | [1, 4]
5               | 0101   | 1   | [5, 5]
6               | 0110   | 2   | [5, 6]
7               | 0111   | 1   | [7, 7]
8               | 1000   | 8   | [1, 8]

LSB = i & (-i)  ← bit trick to extract the LSB (in two's complement, -i = ~i + 1)
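A quick sketch that regenerates the table above. It relies on Python's integers, whose bitwise operators behave like infinite-width two's complement, so `i & -i` works for any positive `i`. The helper names `lsb` and `covered_range` are illustrative.

```python
from typing import Tuple


def lsb(i: int) -> int:
    """Illustrative helper: lowest set bit of i. -i flips every bit above
    the lowest 1, so i & -i keeps only that bit."""
    return i & (-i)


def covered_range(i: int) -> Tuple[int, int]:
    """Illustrative helper: the 1-based range [i - LSB(i) + 1, i] of tree[i]."""
    return (i - lsb(i) + 1, i)


for i in range(1, 9):
    # Columns: index, 4-bit binary, LSB, covered range; e.g. "6 0110 2 (5, 6)"
    print(f"{i} {i:04b} {lsb(i)} {covered_range(i)}")
```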

Tree Structure Visualization

tree[i] stores the sum of the range that ends at i and has length LSB(i), i.e. sum[i - LSB(i) + 1 .. i]

Array: [_, 1, 3, 5, 7, 9, 11, 13, 15]  (1-indexed, _ is unused)

Parent of node i = i + LSB(i); each tree[i] = A[i] + the sum of its children.

tree[8] = sum[1..8] = 64
├── tree[4] = sum[1..4] = 16
│   ├── tree[2] = sum[1..2] = 4
│   │   └── tree[1] = A[1] = 1
│   └── tree[3] = A[3] = 5
├── tree[6] = sum[5..6] = 20
│   └── tree[5] = A[5] = 9
└── tree[7] = A[7] = 13
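The values in the diagram can be checked by computing every `tree[i]` straight from its definition. This O(N log N) sketch is for verification only (the name `build_naive` is illustrative); the `from_array` method in the complete implementation below builds the same array in O(N).

```python
from typing import List


def build_naive(a: List[int]) -> List[int]:
    """Illustrative checker: tree[i] = A[i - LSB(i) + 1] + ... + A[i].

    `a` is 1-indexed: a[0] is an unused placeholder, like "_" above.
    """
    n = len(a) - 1
    tree = [0] * (n + 1)
    for i in range(1, n + 1):
        start = i - (i & (-i)) + 1
        tree[i] = sum(a[start : i + 1])
    return tree


A = [0, 1, 3, 5, 7, 9, 11, 13, 15]  # index 0 plays the role of "_"
print(build_naive(A))  # [0, 1, 4, 5, 16, 9, 20, 13, 64]
```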

Core Operations

Query: Prefix Sum [1..i]

Idea: walk from i downward, clearing the LSB at each step until i reaches 0.

python
def prefix_sum(i: int) -> int:
    """Sum of elements [1..i]. Reads a global 1-indexed list `tree`."""
    total = 0
    while i > 0:
        total += tree[i]
        i -= i & (-i)  # Remove LSB
    return total

Example: prefix_sum(7)

i=7 (0111): total += tree[7], i = 7 - 1 = 6
i=6 (0110): total += tree[6], i = 6 - 2 = 4
i=4 (0100): total += tree[4], i = 4 - 4 = 0
Done! total = tree[7] + tree[6] + tree[4] = A[7] + sum[5..6] + sum[1..4] = sum[1..7]
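The walk depends only on the index, not on the data, so it can be sketched on its own (the helper name `prefix_path` is illustrative). The number of steps equals the number of 1 bits in i, which is why a query costs O(log N). The `tree` list below is the one from the example array.

```python
from typing import List


def prefix_path(i: int) -> List[int]:
    """Illustrative helper: indices of tree[] read by prefix_sum(i)."""
    path = []
    while i > 0:
        path.append(i)
        i -= i & (-i)  # clear the lowest set bit
    return path


print(prefix_path(7))   # [7, 6, 4]
print(prefix_path(13))  # [13, 12, 8]

tree = [0, 1, 4, 5, 16, 9, 20, 13, 64]  # tree[] for A[1..8] = 1, 3, ..., 15
print(sum(tree[j] for j in prefix_path(7)))  # 49 = 1+3+5+7+9+11+13
```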

Update: Add delta to A[i]

Idea: walk from i up to N, adding the LSB at each step.

python
def update(i: int, delta: int) -> None:
    """Add delta to element at index i. Uses a global `tree` of size `n`."""
    while i <= n:
        tree[i] += delta
        i += i & (-i)  # Add LSB

Example: update(3, 5) with n = 8

i=3 (0011): tree[3] += 5, i = 3 + 1 = 4
i=4 (0100): tree[4] += 5, i = 4 + 4 = 8
i=8 (1000): tree[8] += 5, i = 8 + 8 = 16 > n
Done!
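The update walk can also be sketched in isolation (the helper name `update_path` is illustrative). The indices it visits are exactly the `tree[j]` whose range [j - LSB(j) + 1, j] contains i, which is why only those cells need the delta.

```python
from typing import List


def update_path(i: int, n: int) -> List[int]:
    """Illustrative helper: indices of tree[] touched by update(i, delta)."""
    path = []
    while i <= n:
        path.append(i)
        i += i & (-i)  # move to the next node whose range also covers i
    return path


print(update_path(3, 8))  # [3, 4, 8]
print(update_path(5, 8))  # [5, 6, 8]
print(update_path(1, 8))  # [1, 2, 4, 8]
```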

Range Sum [L..R]

python
def range_sum(L: int, R: int) -> int:
    """Sum of elements [L..R]."""
    return prefix_sum(R) - prefix_sum(L - 1)
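A self-contained version of the three snippets, checked on the example array A[1..8]. As an assumption of this sketch, `tree` is passed explicitly instead of being read from a global, and `build` is an illustrative helper. Note that L = 1 is safe because the loop in `prefix_sum(tree, 0)` never runs and returns 0.

```python
from typing import List


def build(a: List[int]) -> List[int]:
    """Illustrative helper: 1-indexed Fenwick array for `a` via repeated updates."""
    n = len(a)
    tree = [0] * (n + 1)
    for idx, val in enumerate(a, start=1):
        i = idx
        while i <= n:
            tree[i] += val
            i += i & (-i)
    return tree


def prefix_sum(tree: List[int], i: int) -> int:
    total = 0
    while i > 0:
        total += tree[i]
        i -= i & (-i)
    return total


def range_sum(tree: List[int], left: int, right: int) -> int:
    # sum[left..right] = sum[1..right] - sum[1..left-1]
    return prefix_sum(tree, right) - prefix_sum(tree, left - 1)


A = [1, 3, 5, 7, 9, 11, 13, 15]  # A[1..8] from the example
tree = build(A)
print(range_sum(tree, 3, 6))  # 5 + 7 + 9 + 11 = 32
```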

Complete Implementation

python
# HPN Engineering Standard
# Implementation: Fenwick Tree (Binary Indexed Tree)

from typing import List, Optional


class FenwickTree:
    """
    Fenwick Tree (Binary Indexed Tree) for efficient prefix sum queries.
    
    Supports:
    - Point update: O(log N)
    - Prefix sum query: O(log N)
    - Range sum query: O(log N)
    
    Note: 1-indexed internally for bit manipulation convenience.
    """
    
    def __init__(self, n: int):
        """Initialize empty Fenwick Tree of size n."""
        self.n = n
        self.tree = [0] * (n + 1)  # 1-indexed
    
    @classmethod
    def from_array(cls, arr: List[int]) -> 'FenwickTree':
        """
        Build Fenwick Tree from array.
        
        Time: O(N) using efficient building method.
        """
        n = len(arr)
        ft = cls(n)
        
        # Copy values first
        for i in range(n):
            ft.tree[i + 1] = arr[i]
        
        # Build tree in O(N)
        for i in range(1, n + 1):
            parent = i + (i & (-i))
            if parent <= n:
                ft.tree[parent] += ft.tree[i]
        
        return ft
    
    def update(self, i: int, delta: int):
        """
        Add delta to element at index i (0-indexed).
        
        Time: O(log N)
        """
        i += 1  # Convert to 1-indexed
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)
    
    def set(self, i: int, val: int):
        """
        Set element at index i to val.
        
        Requires knowing current value.
        Time: O(log N)
        """
        current = self.range_sum(i, i)
        self.update(i, val - current)
    
    def prefix_sum(self, i: int) -> int:
        """
        Sum of elements [0..i] (0-indexed).
        
        Time: O(log N)
        """
        i += 1  # Convert to 1-indexed
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & (-i)
        return total
    
    def range_sum(self, L: int, R: int) -> int:
        """
        Sum of elements [L..R] (0-indexed, inclusive).
        
        Time: O(log N)
        """
        if L == 0:
            return self.prefix_sum(R)
        return self.prefix_sum(R) - self.prefix_sum(L - 1)
    
    def find_kth(self, k: int) -> int:
        """
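        Find smallest index i such that prefix_sum(i) >= k.
        
        Useful for order statistics. Assumes all elements are non-negative,
        so prefix sums never decrease. Returns self.n if k exceeds the total.
        Time: O(log N)
        """
        idx = 0
        # Highest power of two <= n; 0 for an empty tree (loop is skipped)
        bit_mask = 1 << (self.n.bit_length() - 1) if self.n else 0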
        
        while bit_mask > 0:
            next_idx = idx + bit_mask
            if next_idx <= self.n and self.tree[next_idx] < k:
                idx = next_idx
                k -= self.tree[idx]
            bit_mask >>= 1
        
        return idx  # 0-indexed


# ============================================
# 2D FENWICK TREE
# ============================================

class FenwickTree2D:
    """
    2D Fenwick Tree for 2D range sum queries.
    
    Time: O(log N × log M) per operation
    Space: O(N × M)
    """
    
    def __init__(self, n: int, m: int):
        self.n = n
        self.m = m
        self.tree = [[0] * (m + 1) for _ in range(n + 1)]
    
    def update(self, x: int, y: int, delta: int):
        """Add delta to element at (x, y)."""
        x += 1
        while x <= self.n:
            j = y + 1
            while j <= self.m:
                self.tree[x][j] += delta
                j += j & (-j)
            x += x & (-x)
    
    def prefix_sum(self, x: int, y: int) -> int:
        """Sum of rectangle [(0,0), (x,y)]."""
        x += 1
        total = 0
        while x > 0:
            j = y + 1
            while j > 0:
                total += self.tree[x][j]
                j -= j & (-j)
            x -= x & (-x)
        return total
    
    def range_sum(self, x1: int, y1: int, x2: int, y2: int) -> int:
        """Sum of rectangle [(x1,y1), (x2,y2)]."""
        total = self.prefix_sum(x2, y2)
        if x1 > 0:
            total -= self.prefix_sum(x1 - 1, y2)
        if y1 > 0:
            total -= self.prefix_sum(x2, y1 - 1)
        if x1 > 0 and y1 > 0:
            total += self.prefix_sum(x1 - 1, y1 - 1)
        return total


# ============================================
# APPLICATIONS
# ============================================

def count_inversions(arr: List[int]) -> int:
    """
    Count inversions in array using Fenwick Tree.
    
    Inversion: pair (i, j) where i < j but arr[i] > arr[j]
    
    Time: O(N log N)
    Space: O(N) (coordinate compression bounds the tree by the number of distinct values)
    """
    if not arr:
        return 0
    
    # Coordinate compression
    sorted_unique = sorted(set(arr))
    rank = {v: i for i, v in enumerate(sorted_unique)}
    
    n = len(sorted_unique)
    ft = FenwickTree(n)
    inversions = 0
    
    # Process from right to left
    for i in range(len(arr) - 1, -1, -1):
        r = rank[arr[i]]
        # Count elements smaller than arr[i] to the right
        if r > 0:
            inversions += ft.prefix_sum(r - 1)
        ft.update(r, 1)
    
    return inversions


def dynamic_ranking(events: List[tuple]) -> List[Optional[int]]:
    """
    Process dynamic ranking queries.
    
    Events: 
    - ('add', value): Add one copy of value to the multiset
    - ('remove', value): Remove one copy of value (assumed present)
    - ('rank', value): Return the count of elements <= value (its 1-indexed rank when values are distinct)
    
    Time: O(K log N) for K events
    """
    # First pass: collect all values for coordinate compression
    values = set()
    for event in events:
        values.add(event[1])
    
    sorted_vals = sorted(values)
    rank = {v: i for i, v in enumerate(sorted_vals)}
    
    ft = FenwickTree(len(sorted_vals))
    results = []
    
    for event in events:
        op, val = event[0], event[1]
        r = rank[val]
        
        if op == 'add':
            ft.update(r, 1)
            results.append(None)
        elif op == 'remove':
            ft.update(r, -1)
            results.append(None)
        elif op == 'rank':
            # Rank = count of elements <= val
            results.append(ft.prefix_sum(r))
    
    return results


# ============================================
# USAGE EXAMPLE
# ============================================

if __name__ == "__main__":
    # Example 1: Basic Fenwick Tree
    print("=== Fenwick Tree Demo ===")
    arr = [1, 3, 5, 7, 9, 11]
    ft = FenwickTree.from_array(arr)
    
    print(f"Array: {arr}")
    print(f"Prefix sum [0..2]: {ft.prefix_sum(2)}")  # 1+3+5 = 9
    print(f"Range sum [1..4]: {ft.range_sum(1, 4)}")  # 3+5+7+9 = 24
    
    ft.update(2, 5)  # Add 5 to index 2
    print(f"After update(2, +5):")
    print(f"Range sum [1..4]: {ft.range_sum(1, 4)}")  # 3+10+7+9 = 29
    
    # Example 2: Count Inversions
    print("\n=== Count Inversions ===")
    test_arr = [8, 4, 2, 1]
    inv = count_inversions(test_arr)
    print(f"Array: {test_arr}")
    print(f"Inversions: {inv}")  # 6: (8,4), (8,2), (8,1), (4,2), (4,1), (2,1)
    
    # Example 3: 2D Fenwick Tree
    print("\n=== 2D Fenwick Tree ===")
    ft2d = FenwickTree2D(3, 3)
    matrix = [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]
    ]
    
    for i in range(3):
        for j in range(3):
            ft2d.update(i, j, matrix[i][j])
    
    print(f"Matrix:")
    for row in matrix:
        print(f"  {row}")
    
    print(f"Sum of rectangle [(0,0), (1,1)]: {ft2d.range_sum(0, 0, 1, 1)}")  # 1+2+4+5 = 12
    print(f"Sum of rectangle [(1,1), (2,2)]: {ft2d.range_sum(1, 1, 2, 2)}")  # 5+6+8+9 = 28
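The demo above does not exercise `find_kth`. A common use is order statistics on a multiset of small integers: keep one count per value and binary-lift to the k-th item. This sketch re-implements only the needed pieces so it runs on its own (the class name `CountFenwick` is hypothetical); its search loop mirrors `find_kth`, including a guard for n = 0.

```python
from typing import List


class CountFenwick:
    """Hypothetical minimal Fenwick over value counts (0-indexed API).
    Assumes counts stay non-negative, as find_kth requires."""

    def __init__(self, n: int):
        self.n = n
        self.tree: List[int] = [0] * (n + 1)

    def add(self, i: int, delta: int) -> None:
        i += 1
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)

    def find_kth(self, k: int) -> int:
        """0-indexed position of the k-th item (k >= 1); n if k > total."""
        idx = 0
        mask = 1 << (self.n.bit_length() - 1) if self.n else 0
        while mask:
            nxt = idx + mask
            # Skip the whole block tree[nxt] if it holds fewer than k items
            if nxt <= self.n and self.tree[nxt] < k:
                idx = nxt
                k -= self.tree[nxt]
            mask >>= 1
        return idx


# Multiset {2, 2, 5, 7} over values 0..9, stored as one count per value.
cf = CountFenwick(10)
for v in [7, 2, 5, 2]:
    cf.add(v, 1)
print([cf.find_kth(k) for k in range(1, 5)])  # [2, 2, 5, 7]
cf.add(2, -1)          # remove one copy of 2
print(cf.find_kth(2))  # 5
```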

Complexity Comparison

Operation       | Segment Tree | Fenwick Tree
----------------|--------------|-------------
Build           | O(N)         | O(N)
Query           | O(log N)     | O(log N)
Update          | O(log N)     | O(log N)
Space           | ~4N slots    | N + 1 slots
Code complexity | ~50 lines    | ~15 lines

Fenwick Tree vs Segment Tree

Aspect          | Fenwick Tree | Segment Tree
----------------|--------------|--------------------
Sum/XOR queries | ✅ Perfect   | ✅ Works
Min/Max queries | ❌ Cannot    | ✅ Works
Range updates   | ⚠️ Tricky    | ✅ Lazy propagation
Implementation  | Simple       | Complex
Memory          | N            | 4N
Constant factor | Smaller      | Larger
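On the "Tricky" range-update row: one well-known workaround, sketched here as an illustration rather than as part of the original implementation, keeps two Fenwick arrays `b1` and `b2` (names ours). After adding x to A[l..r], prefix(i) must equal x·(i - l + 1) for l ≤ i ≤ r and x·(r - l + 1) for i > r; storing x, -x in `b1` and x·(l - 1), -x·r in `b2` makes prefix(i) = sum(b1, i)·i - sum(b2, i) produce exactly that, so range add and range sum both stay O(log N).

```python
from typing import List


class RangeFenwick:
    """Hypothetical range-add / range-sum Fenwick (1-indexed API, sketch)."""

    def __init__(self, n: int):
        self.n = n
        self.b1: List[int] = [0] * (n + 1)
        self.b2: List[int] = [0] * (n + 1)

    def _add(self, tree: List[int], i: int, delta: int) -> None:
        while i <= self.n:  # i = n + 1 is a no-op, which is what we want
            tree[i] += delta
            i += i & (-i)

    def _sum(self, tree: List[int], i: int) -> int:
        total = 0
        while i > 0:
            total += tree[i]
            i -= i & (-i)
        return total

    def range_add(self, left: int, right: int, x: int) -> None:
        """Add x to every A[left..right]."""
        self._add(self.b1, left, x)
        self._add(self.b1, right + 1, -x)
        self._add(self.b2, left, x * (left - 1))
        self._add(self.b2, right + 1, -x * right)

    def prefix_sum(self, i: int) -> int:
        return self._sum(self.b1, i) * i - self._sum(self.b2, i)

    def range_sum(self, left: int, right: int) -> int:
        return self.prefix_sum(right) - self.prefix_sum(left - 1)


rf = RangeFenwick(8)
rf.range_add(2, 5, 3)      # A = [0, 3, 3, 3, 3, 0, 0, 0]
rf.range_add(4, 8, 1)      # A = [0, 3, 3, 4, 4, 1, 1, 1]
print(rf.range_sum(1, 8))  # 17
print(rf.range_sum(4, 6))  # 4 + 4 + 1 = 9
```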

When to Use

💡 HPN's Rule

"Sum + Point Update = Fenwick Tree (simpler, faster). Everything else = Segment Tree."