
Segment Tree - The Range Query King

"Segment Tree is the Swiss Army knife of range queries. Master it and you can solve 80% of competitive programming problems." - HPN

Problem Statement

Given an array A of N elements, process Q queries:

  1. Query(L, R): Compute the sum/min/max of A[L..R] (inclusive)
  2. Update(i, val): Set A[i] = val

Example (0-indexed):

A = [1, 3, 5, 7, 9, 11]

Query(1, 3) = 3 + 5 + 7 = 15
Update(2, 10)
A = [1, 3, 10, 7, 9, 11]
Query(1, 3) = 3 + 10 + 7 = 20
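For reference, the naive approach fits in a few lines of Python. This is a minimal sketch covering sum queries only; the class name `NaiveRangeSum` is hypothetical. It costs O(N) per query and O(1) per update, and it doubles as a brute-force oracle when testing a segment tree.

```python
from typing import List


class NaiveRangeSum:
    """Brute-force baseline: O(N) query, O(1) point update."""

    def __init__(self, data: List[int]):
        self.a = list(data)  # private copy of the input

    def query(self, L: int, R: int) -> int:
        # Sum of A[L..R]; R is inclusive, so slice up to R + 1
        return sum(self.a[L:R + 1])

    def update(self, i: int, val: int) -> None:
        self.a[i] = val


naive = NaiveRangeSum([1, 3, 5, 7, 9, 11])
print(naive.query(1, 3))  # 15
naive.update(2, 10)
print(naive.query(1, 3))  # 20
```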

Why Segment Tree?

| Approach | Query | Update | Q mixed queries on N elements (worst case) |
|---|---|---|---|
| Naive | O(N) | O(1) | O(Q × N) |
| Prefix Sum | O(1) | O(N) | O(Q × N) |
| Segment Tree | O(log N) | O(log N) | O(N + Q × log N) ✅ |

With 100K queries on 100K elements:

  • Naive: 100K × 100K = 10 billion operations → times out
  • Segment Tree: 100K × log₂(100K) ≈ 100K × 17 ≈ 1.7 million operations → fast
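The Prefix Sum row can be illustrated with a short sketch (the class name `PrefixSumArray` is hypothetical). A query is a single subtraction, but an update invalidates every prefix from index i onward, so this version simply rebuilds the prefix array in O(N).

```python
from itertools import accumulate
from typing import List


class PrefixSumArray:
    """Prefix sums: O(1) range-sum query, O(N) point update."""

    def __init__(self, data: List[int]):
        self.a = list(data)
        self._rebuild()

    def _rebuild(self) -> None:
        # prefix[k] = a[0] + ... + a[k-1], with prefix[0] = 0
        self.prefix = [0] + list(accumulate(self.a))

    def query(self, L: int, R: int) -> int:
        # Sum of a[L..R] inclusive in O(1)
        return self.prefix[R + 1] - self.prefix[L]

    def update(self, i: int, val: int) -> None:
        # Every prefix after i changes, so rebuild in O(N)
        self.a[i] = val
        self._rebuild()


ps = PrefixSumArray([1, 3, 5, 7, 9, 11])
print(ps.query(1, 3))  # 15
ps.update(2, 10)
print(ps.query(1, 3))  # 20
```

With Q updates interleaved with queries, those O(N) rebuilds are what push the total to O(Q × N).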

Tree Structure Visualization

Array: [1, 3, 5, 7, 9, 11] (indices 0-5)

                    [36]                    ← sum of entire array
                [0, 5]
               /      \
           [9]          [27]                ← left/right halves
         [0, 2]        [3, 5]
         /    \        /    \
      [4]     [5]   [16]    [11]           ← further splits
    [0,1]   [2,2]  [3,4]   [5,5]
    /   \           /   \
  [1]   [3]      [7]   [9]                 ← leaf nodes
 [0,0] [1,1]    [3,3] [4,4]

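Each node stores:
- The sum of its range
- [L, R] = the range it covers (drawn here for readability; the array-based code below does not store it, it passes start/end down the recursion instead)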
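To check the diagram, the sketch below runs the same recursive split with 0-based child indices 2i + 1 and 2i + 2 and lists every node's index, range, and sum. The helper name `node_table` is hypothetical.

```python
from typing import List, Tuple


def node_table(arr: List[int]) -> List[Tuple[int, int, int, int]]:
    """Return (node index, start, end, sum) for every node, sorted by index."""
    rows = []

    def build(node: int, start: int, end: int) -> int:
        if start == end:
            total = arr[start]                      # leaf
        else:
            mid = (start + end) // 2
            total = (build(2 * node + 1, start, mid) +
                     build(2 * node + 2, mid + 1, end))
        rows.append((node, start, end, total))
        return total

    build(0, 0, len(arr) - 1)
    return sorted(rows)


for node, start, end, total in node_table([1, 3, 5, 7, 9, 11]):
    print(f"node {node:2}: [{start}, {end}] sum = {total}")
```

For this array it reports 11 nodes but uses indices up to 12. Slots 9 and 10 are never filled because node 4 (range [2, 2]) is a leaf. Gaps like these are why the backing array is sized generously.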

Key Properties

| Property | Value |
|---|---|
| Height | O(log N) (exactly ⌈log₂ N⌉) |
| Number of nodes | 2N - 1 (allocate 4N array slots to be safe) |
| Storage | O(N) |
| Each node | Stores the aggregate of its range |
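The node count and the 4N allocation can be checked empirically. This sketch (hypothetical helper `index_stats`) walks the same split without storing any values and reports how many nodes exist and the highest array index touched.

```python
from typing import Tuple


def index_stats(n: int) -> Tuple[int, int]:
    """Return (number of nodes, highest node index) for a tree over n >= 1 leaves."""
    count, highest = 0, 0
    stack = [(0, 0, n - 1)]                     # (node index, start, end)
    while stack:
        node, start, end = stack.pop()
        count += 1
        highest = max(highest, node)
        if start != end:
            mid = (start + end) // 2
            stack.append((2 * node + 1, start, mid))
            stack.append((2 * node + 2, mid + 1, end))
    return count, highest


print(index_stats(6))  # (11, 12)

# 2N - 1 nodes every time, and the highest index always stays below 4N
for n in range(1, 300):
    nodes, highest = index_stats(n)
    assert nodes == 2 * n - 1 and highest < 4 * n
```

For N = 6 the highest index is 12, so an array of exactly 2N = 12 slots would already overflow.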

Building the Tree

python
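# Globals shared by the build/query/update snippets in this section
arr = [1, 3, 5, 7, 9, 11]
tree = [0] * (4 * len(arr))  # 4N slots is always enough


def build(node: int, start: int, end: int):
    """
    Recursive tree building.
    
    node: current node index (children of node i are 2i + 1 and 2i + 2)
    [start, end]: range this node is responsible for
    """
    if start == end:
        # Leaf node: store element directly
        tree[node] = arr[start]
    else:
        mid = (start + end) // 2
        left_child = 2 * node + 1
        right_child = 2 * node + 2
        
        # Build children recursively
        build(left_child, start, mid)
        build(right_child, mid + 1, end)
        
        # Combine children results
        tree[node] = tree[left_child] + tree[right_child]


build(0, 0, len(arr) - 1)  # root is node 0, covering [0, N-1]; tree[0] == 36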

Query Logic

Three cases when querying range [L, R]:

Node range: [start, end]
Query range: [L, R]

Case 1: [start, end] lies completely OUTSIDE [L, R]
        → Return 0 (or, in general, the identity element)

Case 2: [start, end] lies completely INSIDE [L, R]
        → Return tree[node]

Case 3: Partial overlap
        → Query both children and combine the results
python
def query(node: int, start: int, end: int, L: int, R: int):
    """
    Range query [L, R].
    """
    # Case 1: Completely outside
    if R < start or end < L:
        return 0  # Identity for sum
    
    # Case 2: Completely inside
    if L <= start and end <= R:
        return tree[node]
    
    # Case 3: Partial overlap
    mid = (start + end) // 2
    left_result = query(2*node+1, start, mid, L, R)
    right_result = query(2*node+2, mid+1, end, L, R)
    
    return left_result + right_result
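To see the three cases in action, here is a self-contained variant of `query` that also records which case each visited node hits. The name `traced_query` and its `log` parameter are hypothetical additions for illustration.

```python
from typing import List, Tuple

arr = [1, 3, 5, 7, 9, 11]
tree = [0] * (4 * len(arr))


def build(node: int, start: int, end: int) -> None:
    if start == end:
        tree[node] = arr[start]
        return
    mid = (start + end) // 2
    build(2 * node + 1, start, mid)
    build(2 * node + 2, mid + 1, end)
    tree[node] = tree[2 * node + 1] + tree[2 * node + 2]


def traced_query(node: int, start: int, end: int, L: int, R: int,
                 log: List[Tuple[int, int, int, str]]) -> int:
    if R < start or end < L:                    # Case 1
        log.append((node, start, end, "outside"))
        return 0
    if L <= start and end <= R:                 # Case 2
        log.append((node, start, end, "inside"))
        return tree[node]
    log.append((node, start, end, "partial"))   # Case 3
    mid = (start + end) // 2
    return (traced_query(2 * node + 1, start, mid, L, R, log) +
            traced_query(2 * node + 2, mid + 1, end, L, R, log))


build(0, 0, len(arr) - 1)
log: List[Tuple[int, int, int, str]] = []
print(traced_query(0, 0, len(arr) - 1, 1, 3, log))  # 15
for entry in log:
    print(entry)
```

For Query(1, 3) the answer comes from the three fully covered nodes [1, 1], [2, 2] and [3, 3]. On a larger array, most subtrees are settled by Case 1 or Case 2 without ever being opened.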

Point Update

python
def update(node: int, start: int, end: int, idx: int, val: int):
    """
    Update arr[idx] = val.
    """
    if start == end:
        # Leaf node
        tree[node] = val
    else:
        mid = (start + end) // 2
        
        if idx <= mid:
            update(2*node+1, start, mid, idx, val)
        else:
            update(2*node+2, mid+1, end, idx, val)
        
        # Recalculate after child update
        tree[node] = tree[2*node+1] + tree[2*node+2]
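A point update only recomputes the nodes on one root-to-leaf path. The self-contained sketch below applies Update(2, 10) from the problem statement and records that path. The name `traced_update` and its `path` list are hypothetical additions for illustration.

```python
from typing import List

arr = [1, 3, 5, 7, 9, 11]
tree = [0] * (4 * len(arr))


def build(node: int, start: int, end: int) -> None:
    if start == end:
        tree[node] = arr[start]
        return
    mid = (start + end) // 2
    build(2 * node + 1, start, mid)
    build(2 * node + 2, mid + 1, end)
    tree[node] = tree[2 * node + 1] + tree[2 * node + 2]


def traced_update(node: int, start: int, end: int, idx: int, val: int,
                  path: List[int]) -> None:
    path.append(node)                           # every node on the path gets recomputed
    if start == end:
        tree[node] = val
        return
    mid = (start + end) // 2
    if idx <= mid:
        traced_update(2 * node + 1, start, mid, idx, val, path)
    else:
        traced_update(2 * node + 2, mid + 1, end, idx, val, path)
    tree[node] = tree[2 * node + 1] + tree[2 * node + 2]


build(0, 0, len(arr) - 1)
path: List[int] = []
traced_update(0, 0, len(arr) - 1, 2, 10, path)
print(path)     # [0, 1, 4]
print(tree[0])  # 41
```

Only 3 of the 11 nodes change: leaf [2, 2] becomes 10, its parent [0, 2] becomes 4 + 10 = 14, and the root becomes 14 + 27 = 41.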

Production Implementation

python
# HPN Engineering Standard
# Implementation: Segment Tree - Full Featured

from typing import List, Callable, TypeVar, Generic
from dataclasses import dataclass


T = TypeVar('T')


class SegmentTree(Generic[T]):
    """
    Generic Segment Tree supporting any associative operation.
    
    Time Complexity:
    - Build: O(N)
    - Query: O(log N)
    - Update: O(log N)
    
    Space: O(N)
    """
    
    def __init__(
        self, 
        data: List[T], 
        combine: Callable[[T, T], T],
        identity: T
    ):
        """
        Args:
            data: Input array
            combine: Associative function (e.g., add, min, max)
            identity: Identity element (0 for sum, inf for min, -inf for max)
        """
        self.n = len(data)
        self.combine = combine
        self.identity = identity
        self.tree = [identity] * (4 * self.n)
        if self.n > 0:  # guard: an empty array would otherwise recurse forever
            self._build(data, 0, 0, self.n - 1)
    
    def _build(self, data: List[T], node: int, start: int, end: int):
        """Build tree recursively."""
        if start == end:
            self.tree[node] = data[start]
        else:
            mid = (start + end) // 2
            left, right = 2 * node + 1, 2 * node + 2
            
            self._build(data, left, start, mid)
            self._build(data, right, mid + 1, end)
            
            self.tree[node] = self.combine(self.tree[left], self.tree[right])
    
    def query(self, L: int, R: int) -> T:
        """Query range [L, R] (inclusive)."""
        return self._query(0, 0, self.n - 1, L, R)
    
    def _query(self, node: int, start: int, end: int, L: int, R: int) -> T:
        """Internal query."""
        if R < start or end < L:
            return self.identity
        
        if L <= start and end <= R:
            return self.tree[node]
        
        mid = (start + end) // 2
        left_result = self._query(2*node+1, start, mid, L, R)
        right_result = self._query(2*node+2, mid+1, end, L, R)
        
        return self.combine(left_result, right_result)
    
    def update(self, idx: int, val: T):
        """Point update: arr[idx] = val."""
        self._update(0, 0, self.n - 1, idx, val)
    
    def _update(self, node: int, start: int, end: int, idx: int, val: T):
        """Internal update."""
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            
            if idx <= mid:
                self._update(2*node+1, start, mid, idx, val)
            else:
                self._update(2*node+2, mid+1, end, idx, val)
            
            self.tree[node] = self.combine(
                self.tree[2*node+1], 
                self.tree[2*node+2]
            )


# ============================================
# SPECIALIZED IMPLEMENTATIONS
# ============================================

class SumSegmentTree:
    """Sum-only Segment Tree: same logic as the generic one, with + inlined instead of a combine callback."""
    
    def __init__(self, data: List[int]):
        self.n = len(data)
        self.tree = [0] * (4 * self.n)
        if self.n > 0:  # guard: an empty array would otherwise recurse forever
            self._build(data, 0, 0, self.n - 1)
    
    def _build(self, data, node, start, end):
        if start == end:
            self.tree[node] = data[start]
        else:
            mid = (start + end) // 2
            self._build(data, 2*node+1, start, mid)
            self._build(data, 2*node+2, mid+1, end)
            self.tree[node] = self.tree[2*node+1] + self.tree[2*node+2]
    
    def query(self, L: int, R: int) -> int:
        return self._query(0, 0, self.n-1, L, R)
    
    def _query(self, node, start, end, L, R):
        if R < start or end < L:
            return 0
        if L <= start and end <= R:
            return self.tree[node]
        mid = (start + end) // 2
        return (self._query(2*node+1, start, mid, L, R) + 
                self._query(2*node+2, mid+1, end, L, R))
    
    def update(self, idx: int, val: int):
        self._update(0, 0, self.n-1, idx, val)
    
    def _update(self, node, start, end, idx, val):
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            if idx <= mid:
                self._update(2*node+1, start, mid, idx, val)
            else:
                self._update(2*node+2, mid+1, end, idx, val)
            self.tree[node] = self.tree[2*node+1] + self.tree[2*node+2]


class MinMaxSegmentTree:
    """Segment Tree tracking both min and max."""
    
    def __init__(self, data: List[int]):
        self.n = len(data)
        self.min_tree = [float('inf')] * (4 * self.n)
        self.max_tree = [float('-inf')] * (4 * self.n)
        if self.n > 0:  # guard: an empty array would otherwise recurse forever
            self._build(data, 0, 0, self.n - 1)
    
    def _build(self, data, node, start, end):
        if start == end:
            self.min_tree[node] = data[start]
            self.max_tree[node] = data[start]
        else:
            mid = (start + end) // 2
            left, right = 2*node+1, 2*node+2
            self._build(data, left, start, mid)
            self._build(data, right, mid+1, end)
            self.min_tree[node] = min(self.min_tree[left], self.min_tree[right])
            self.max_tree[node] = max(self.max_tree[left], self.max_tree[right])
    
    def query_min(self, L: int, R: int) -> int:
        return self._query_min(0, 0, self.n-1, L, R)
    
    def query_max(self, L: int, R: int) -> int:
        return self._query_max(0, 0, self.n-1, L, R)
    
    def _query_min(self, node, start, end, L, R):
        if R < start or end < L:
            return float('inf')
        if L <= start and end <= R:
            return self.min_tree[node]
        mid = (start + end) // 2
        return min(self._query_min(2*node+1, start, mid, L, R),
                   self._query_min(2*node+2, mid+1, end, L, R))
    
    def _query_max(self, node, start, end, L, R):
        if R < start or end < L:
            return float('-inf')
        if L <= start and end <= R:
            return self.max_tree[node]
        mid = (start + end) // 2
        return max(self._query_max(2*node+1, start, mid, L, R),
                   self._query_max(2*node+2, mid+1, end, L, R))
    
    def update(self, idx: int, val: int):
        self._update(0, 0, self.n-1, idx, val)
    
    def _update(self, node, start, end, idx, val):
        if start == end:
            self.min_tree[node] = val
            self.max_tree[node] = val
        else:
            mid = (start + end) // 2
            if idx <= mid:
                self._update(2*node+1, start, mid, idx, val)
            else:
                self._update(2*node+2, mid+1, end, idx, val)
            left, right = 2*node+1, 2*node+2
            self.min_tree[node] = min(self.min_tree[left], self.min_tree[right])
            self.max_tree[node] = max(self.max_tree[left], self.max_tree[right])


# ============================================
# REAL-WORLD: STOCK MARKET ANALYSIS
# ============================================

@dataclass
class StockData:
    timestamp: int
    price: float


class StockAnalyzer:
    """
    Analyze stock prices using Segment Tree.
    
    Supports:
    - Min/Max price in time range
    - Price updates
    """
    
    def __init__(self, prices: List[float]):
        # Store prices as integer cents; round() avoids the float truncation errors int() can produce
        self.prices = [round(p * 100) for p in prices]
        self.tree = MinMaxSegmentTree(self.prices)
    
    def get_min_price(self, start_time: int, end_time: int) -> float:
        """Get minimum price in time range."""
        return self.tree.query_min(start_time, end_time) / 100
    
    def get_max_price(self, start_time: int, end_time: int) -> float:
        """Get maximum price in time range."""
        return self.tree.query_max(start_time, end_time) / 100
    
    def update_price(self, time: int, new_price: float):
        """Update price at specific time."""
        self.tree.update(time, round(new_price * 100))
    
    def get_volatility(self, start_time: int, end_time: int) -> float:
        """Get price volatility (max - min) in range."""
        min_p = self.get_min_price(start_time, end_time)
        max_p = self.get_max_price(start_time, end_time)
        return max_p - min_p


# ============================================
# USAGE EXAMPLE
# ============================================

if __name__ == "__main__":
    # Example 1: Basic Sum Segment Tree
    print("=== Sum Segment Tree ===")
    arr = [1, 3, 5, 7, 9, 11]
    st = SumSegmentTree(arr)
    
    print(f"Array: {arr}")
    print(f"Sum[1, 3] = {st.query(1, 3)}")  # 3 + 5 + 7 = 15
    print(f"Sum[0, 5] = {st.query(0, 5)}")  # 36
    
    st.update(2, 10)
    print(f"After update(2, 10):")
    print(f"Sum[1, 3] = {st.query(1, 3)}")  # 3 + 10 + 7 = 20
    
    # Example 2: Generic Segment Tree
    print("\n=== Generic Segment Tree (GCD) ===")
    from math import gcd
    gcd_tree = SegmentTree([12, 18, 24, 30, 36], gcd, 0)
    print(f"GCD[0, 4] = {gcd_tree.query(0, 4)}")  # GCD of all = 6
    print(f"GCD[1, 3] = {gcd_tree.query(1, 3)}")  # GCD(18,24,30) = 6
    
    # Example 3: Stock Analysis
    print("\n=== Stock Market Analysis ===")
    prices = [100.0, 102.5, 98.0, 105.0, 110.0, 108.0, 115.0]
    analyzer = StockAnalyzer(prices)
    
    print(f"Prices: {prices}")
    print(f"Min price [1, 4]: ${analyzer.get_min_price(1, 4)}")    # $98.0
    print(f"Max price [1, 4]: ${analyzer.get_max_price(1, 4)}")    # $110.0
    print(f"Volatility [1, 4]: ${analyzer.get_volatility(1, 4)}")  # $12.0
    
    analyzer.update_price(3, 120.0)
    print(f"\nAfter price update at time 3 → $120:")
    print(f"Max price [1, 4]: ${analyzer.get_max_price(1, 4)}")    # $120.0

Complexity Analysis

| Operation | Time | Space |
|---|---|---|
| Build | O(N) | O(N) |
| Query | O(log N) | O(log N) stack |
| Update | O(log N) | O(log N) stack |
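The O(log N) query bound follows from the fact that at most two nodes per level partially overlap [L, R] (the ones containing L and R), so at most four nodes per level are visited. The sketch below (hypothetical helper `count_visits`) counts visited nodes for random queries on N = 100,000 and compares the worst case seen against 4 × (height + 1).

```python
import math
import random


def count_visits(start: int, end: int, L: int, R: int) -> int:
    """Number of nodes a recursive range query over [L, R] touches."""
    if R < start or end < L or (L <= start and end <= R):
        return 1                                # Case 1 or Case 2: stop here
    mid = (start + end) // 2                    # Case 3: open both children
    return 1 + count_visits(start, mid, L, R) + count_visits(mid + 1, end, L, R)


n = 100_000
height = math.ceil(math.log2(n))                # 17 for n = 100,000
rng = random.Random(42)
worst = 0
for _ in range(2_000):
    L = rng.randrange(n)
    R = rng.randrange(L, n)
    worst = max(worst, count_visits(0, n - 1, L, R))
print(worst, "<=", 4 * (height + 1))
```

The exact worst value depends on the sampled ranges, but it stays in the tens, far from the tens of thousands of elements a linear scan would read.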

Lazy Propagation (Advanced)

⚠️ When Updates are on Ranges

If you need range updates (changing every element of [L, R]), doing one point update per element costs O(log N) × range_size, which is O(N log N) in the worst case.

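Solution: Lazy Propagation. Record the pending update on the highest nodes that fully cover [L, R] and push it down to their children only when a later operation has to visit them, so a range update costs O(log N).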

python
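# tree[] and arr[] are the globals used by build() above.
# lazy[node] = an addition already counted in tree[node]
# but not yet passed on to node's children.
lazy = [0] * len(tree)


def push_down(node, start, end):
    """Pass node's pending addition down to its two children."""
    if lazy[node] != 0 and start != end:
        mid = (start + end) // 2
        for child, lo, hi in ((2*node+1, start, mid), (2*node+2, mid+1, end)):
            tree[child] += lazy[node] * (hi - lo + 1)
            lazy[child] += lazy[node]
        lazy[node] = 0


def range_update_lazy(node, start, end, L, R, val):
    """
    Add val to all elements in [L, R].
    Uses lazy propagation for efficiency.
    Note: query() must also call push_down(node, start, end) before
    recursing into children, or it will read stale child sums.
    """
    # Push pending updates down
    push_down(node, start, end)
    
    if R < start or end < L:
        return
    
    if L <= start and end <= R:
        # Mark lazy update: this node is now exact, its children are pending
        lazy[node] += val
        tree[node] += val * (end - start + 1)
        return
    
    mid = (start + end) // 2
    range_update_lazy(2*node+1, start, mid, L, R, val)
    range_update_lazy(2*node+2, mid+1, end, L, R, val)
    tree[node] = tree[2*node+1] + tree[2*node+2]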
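The update above is only half of the picture: any later query must also push pending values down before it descends. The sketch below packages range add and range sum together in one class. The class name `LazySumTree` and its `add`/`query` methods are hypothetical. Unlike the snippet above, it pushes only when it actually has to descend into a node's children.

```python
from typing import List, Optional


class LazySumTree:
    """Range add + range sum, both O(log N), via lazy propagation."""

    def __init__(self, data: List[int]):
        self.n = len(data)
        self.tree = [0] * (4 * self.n)  # tree[node] = sum of node's range
        self.lazy = [0] * (4 * self.n)  # pending add not yet given to children
        if self.n:
            self._build(data, 0, 0, self.n - 1)

    def _build(self, data: List[int], node: int, start: int, end: int) -> None:
        if start == end:
            self.tree[node] = data[start]
            return
        mid = (start + end) // 2
        self._build(data, 2 * node + 1, start, mid)
        self._build(data, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def _apply(self, node: int, start: int, end: int, val: int) -> None:
        # Add val to every element of [start, end]; children are updated later
        self.tree[node] += val * (end - start + 1)
        self.lazy[node] += val

    def _push(self, node: int, start: int, end: int) -> None:
        if self.lazy[node] and start != end:
            mid = (start + end) // 2
            self._apply(2 * node + 1, start, mid, self.lazy[node])
            self._apply(2 * node + 2, mid + 1, end, self.lazy[node])
            self.lazy[node] = 0

    def add(self, L: int, R: int, val: int,
            node: int = 0, start: int = 0, end: Optional[int] = None) -> None:
        """Add val to every element of [L, R] (inclusive)."""
        end = self.n - 1 if end is None else end
        if R < start or end < L:
            return
        if L <= start and end <= R:
            self._apply(node, start, end, val)
            return
        self._push(node, start, end)            # children must be current first
        mid = (start + end) // 2
        self.add(L, R, val, 2 * node + 1, start, mid)
        self.add(L, R, val, 2 * node + 2, mid + 1, end)
        self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

    def query(self, L: int, R: int,
              node: int = 0, start: int = 0, end: Optional[int] = None) -> int:
        """Sum of [L, R] (inclusive)."""
        end = self.n - 1 if end is None else end
        if R < start or end < L:
            return 0
        if L <= start and end <= R:
            return self.tree[node]
        self._push(node, start, end)
        mid = (start + end) // 2
        return (self.query(L, R, 2 * node + 1, start, mid) +
                self.query(L, R, 2 * node + 2, mid + 1, end))


lt = LazySumTree([1, 3, 5, 7, 9, 11])
lt.add(1, 3, 10)        # A is now [1, 13, 15, 17, 9, 11]
print(lt.query(0, 5))   # 66
print(lt.query(2, 4))   # 41
```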

Common Operations

| Operation | Combine Function | Identity |
|---|---|---|
| Sum | a + b | 0 |
| Min | min(a, b) | +∞ |
| Max | max(a, b) | -∞ |
| GCD | gcd(a, b) | 0 |
| Product | a * b | 1 |
| XOR | a ^ b | 0 |
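Every row of the table plugs into the same recursive structure. The sketch below builds one tree per operation over the data from the GCD example and checks each answer against `functools.reduce`. It is self-contained rather than reusing the `SegmentTree` class above, and the helper name `make_range_query` is hypothetical.

```python
import math
import operator
from functools import reduce
from typing import Any, Callable, List


def make_range_query(data: List[Any], combine: Callable[[Any, Any], Any], identity: Any):
    """Build a segment tree for (combine, identity); return query(L, R)."""
    n = len(data)
    tree = [identity] * (4 * n)

    def build(node: int, start: int, end: int) -> None:
        if start == end:
            tree[node] = data[start]
            return
        mid = (start + end) // 2
        build(2 * node + 1, start, mid)
        build(2 * node + 2, mid + 1, end)
        tree[node] = combine(tree[2 * node + 1], tree[2 * node + 2])

    def query(node: int, start: int, end: int, L: int, R: int):
        if R < start or end < L:
            return identity                     # neutral: leaves the other side unchanged
        if L <= start and end <= R:
            return tree[node]
        mid = (start + end) // 2
        return combine(query(2 * node + 1, start, mid, L, R),
                       query(2 * node + 2, mid + 1, end, L, R))

    build(0, 0, n - 1)
    return lambda L, R: query(0, 0, n - 1, L, R)


OPERATIONS = {
    "Sum": (operator.add, 0),
    "Min": (min, math.inf),
    "Max": (max, -math.inf),
    "GCD": (math.gcd, 0),
    "Product": (operator.mul, 1),
    "XOR": (operator.xor, 0),
}

data = [12, 18, 24, 30, 36]
for name, (combine, identity) in OPERATIONS.items():
    q = make_range_query(data, combine, identity)
    print(name, q(1, 3), reduce(combine, data[1:4]))
```

The identity only matters in Case 1 (a node entirely outside the query), where it must not change the result coming from the other child.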

💡 HPN's Rule

"Segment Tree = Divide & Conquer over an array. Each node stores the aggregate of its range. Update walks one root-to-leaf path, and Query visits at most a few nodes per level, so both run in O(log N)."