Giao diện
Segment Tree - The Range Query King
"Segment Tree là Swiss Army Knife của range queries. Master nó, giải được 80% bài competitive programming." - HPN
Problem Statement
Cho array A có N phần tử, xử lý Q queries:
- Query(L, R): Tính sum/min/max của A[L..R]
- Update(i, val): Đặt A[i] = val
A = [1, 3, 5, 7, 9, 11]
Query(1, 3) = 3 + 5 + 7 = 15
Update(2, 10)
A = [1, 3, 10, 7, 9, 11]
Query(1, 3) = 3 + 10 + 7 = 20
Why Segment Tree?
| Approach | Query | Update | Q queries on N elements |
|---|---|---|---|
| Naive | O(N) | O(1) | O(Q × N) |
| Prefix Sum | O(1) | O(N) | O(Q × N) |
| Segment Tree | O(log N) | O(log N) | O(Q × log N) ✅ |
100K queries × 100K elements:
- Naive: 10 billion ops → Timeout
- Segment Tree: ~1.7 million ops → Fast
Tree Structure Visualization
Array: [1, 3, 5, 7, 9, 11] (indices 0-5)
                 [36]                 ← sum of entire array
                [0, 5]
               /      \
           [9]          [27]          ← left/right halves
          [0, 2]       [3, 5]
          /    \       /    \
       [4]     [5]   [16]    [11]     ← further splits
      [0,1]   [2,2]  [3,4]   [5,5]
      /   \          /   \
    [1]   [3]      [7]   [9]          ← leaf nodes
   [0,0] [1,1]    [3,3] [4,4]
Each node stores:
- Sum of its range
- [L, R] = range covered
Key Properties
| Property | Value |
|---|---|
| Height | O(log N) |
| Number of nodes | ~ 2N (cần 4N để an toàn) |
| Storage | O(N) |
| Each node | Stores aggregate of its range |
Building the Tree
python
def build(node: int, start: int, end: int):
    """
    Recursively fill ``tree`` for the segment ``[start, end]`` of ``arr``.

    node: index in ``tree`` of the node being built
    [start, end]: inclusive range of ``arr`` this node is responsible for

    ``tree`` and ``arr`` are module-level lists that must exist before the
    call; ``tree`` needs room for every node index reached (children of
    ``node`` are ``2 * node + 1`` and ``2 * node + 2``).
    """
    if start == end:
        # A one-element segment is a leaf: it stores the element itself.
        tree[node] = arr[start]
        return

    half = (start + end) // 2
    lhs, rhs = 2 * node + 1, 2 * node + 2
    # Both children must be complete before the parent can be summed.
    build(lhs, start, half)
    build(rhs, half + 1, end)
    tree[node] = tree[lhs] + tree[rhs]


# Query Logic
3 Cases khi query range [L, R]:
Node range: [start, end]
Query range: [L, R]
Case 1: [start, end] hoàn toàn NGOÀI [L, R]
→ Return 0 (hoặc identity element)
Case 2: [start, end] hoàn toàn TRONG [L, R]
→ Return tree[node]
Case 3: Overlap
→ Query cả 2 children, combine kết quả
python
def query(node: int, start: int, end: int, L: int, R: int):
    """
    Return the sum stored for the inclusive query range [L, R].

    node: index in ``tree`` of the node being visited
    [start, end]: inclusive range covered by that node
    """
    # No overlap: contribute the identity for addition.
    if end < L or R < start:
        return 0

    # Node lies entirely inside the query: its stored sum is the answer.
    if L <= start and end <= R:
        return tree[node]

    # Partial overlap: split at the midpoint and add both halves.
    half = (start + end) // 2
    return (
        query(2 * node + 1, start, half, L, R)
        + query(2 * node + 2, half + 1, end, L, R)
    )


# Point Update
python
def update(node: int, start: int, end: int, idx: int, val: int):
    """
    Point update: store ``val`` in the leaf for position ``idx``.

    Only ``tree`` is modified. After the leaf is written, every node on the
    path back up to ``node`` is recomputed from its two children.
    """
    if start == end:
        # Reached the leaf that represents position idx.
        tree[node] = val
        return

    half = (start + end) // 2
    lhs, rhs = 2 * node + 1, 2 * node + 2
    if idx > half:
        update(rhs, half + 1, end, idx, val)
    else:
        update(lhs, start, half, idx, val)
    # A child changed, so this node's sum is stale until recomputed.
    tree[node] = tree[lhs] + tree[rhs]


# Production Implementation
python
# HPN Engineering Standard
# Implementation: Segment Tree - Full Featured
from typing import List, Callable, TypeVar, Generic, Optional
from dataclasses import dataclass
from enum import Enum
T = TypeVar('T')


class SegmentTree(Generic[T]):
    """
    Generic Segment Tree supporting any associative operation.

    Nodes are kept in a flat list in heap order: the root is node 0 and
    covers ``[0, n - 1]``; the children of ``node`` are ``2 * node + 1``
    and ``2 * node + 2``.

    Time Complexity:
    - Build: O(N)
    - Query: O(log N)
    - Update: O(log N)
    Space: O(N)
    """

    def __init__(
        self,
        data: List[T],
        combine: Callable[[T, T], T],
        identity: T
    ):
        """
        Args:
            data: Input array (may be empty).
            combine: Associative function (e.g., add, min, max). It is always
                called as ``combine(left, right)``, so element order is kept.
            identity: Identity element (0 for sum, inf for min, -inf for max)
        """
        self.n = len(data)
        self.combine = combine
        self.identity = identity
        self.tree = [identity] * (4 * self.n)
        # An empty array has no root segment; recursing on [0, -1] would
        # never reach a leaf, so the build is skipped entirely.
        if self.n > 0:
            self._build(data, 0, 0, self.n - 1)

    def _build(self, data: List[T], node: int, start: int, end: int):
        """Build the subtree rooted at ``node`` covering ``[start, end]``."""
        if start == end:
            self.tree[node] = data[start]
            return
        mid = (start + end) // 2
        left, right = 2 * node + 1, 2 * node + 2
        self._build(data, left, start, mid)
        self._build(data, right, mid + 1, end)
        self.tree[node] = self.combine(self.tree[left], self.tree[right])

    def query(self, L: int, R: int) -> T:
        """
        Query range [L, R] (inclusive).

        Positions outside ``[0, n - 1]`` contribute nothing, so the range is
        effectively clipped. A range that overlaps no element (including any
        range on an empty tree) returns ``identity``.
        """
        if self.n == 0:
            return self.identity
        return self._query(0, 0, self.n - 1, L, R)

    def _query(self, node: int, start: int, end: int, L: int, R: int) -> T:
        """Internal query on the node covering ``[start, end]``."""
        # Completely outside the query range.
        if R < start or end < L:
            return self.identity
        # Completely inside the query range.
        if L <= start and end <= R:
            return self.tree[node]
        mid = (start + end) // 2
        left_result = self._query(2 * node + 1, start, mid, L, R)
        right_result = self._query(2 * node + 2, mid + 1, end, L, R)
        return self.combine(left_result, right_result)

    def update(self, idx: int, val: T):
        """
        Point update: arr[idx] = val.

        Raises:
            IndexError: if ``idx`` is not in ``[0, n - 1]``. Without this
                check the descent would end on the first or last leaf and
                overwrite an unrelated element.
        """
        if not 0 <= idx < self.n:
            raise IndexError(
                f"index {idx} out of range for segment tree of size {self.n}"
            )
        self._update(0, 0, self.n - 1, idx, val)

    def _update(self, node: int, start: int, end: int, idx: int, val: T):
        """Internal update on the node covering ``[start, end]``."""
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        left, right = 2 * node + 1, 2 * node + 2
        if idx <= mid:
            self._update(left, start, mid, idx, val)
        else:
            self._update(right, mid + 1, end, idx, val)
        # Recompute this node from its children after the leaf changed.
        self.tree[node] = self.combine(self.tree[left], self.tree[right])
# ============================================
# SPECIALIZED IMPLEMENTATIONS
# ============================================
class SumSegmentTree:
    """
    Segment Tree specialised for range-sum queries with point updates.

    Nodes are stored in a flat list in heap order: the root is node 0 and
    the children of ``node`` are ``2 * node + 1`` and ``2 * node + 2``.
    """

    def __init__(self, data: List[int]):
        """Build the tree over ``data`` (which may be empty)."""
        self.n = len(data)
        self.tree = [0] * (4 * self.n)
        # Recursing on the empty segment [0, -1] would never terminate.
        if self.n > 0:
            self._build(data, 0, 0, self.n - 1)

    def _build(self, data, node, start, end):
        """Fill the subtree rooted at ``node`` covering ``[start, end]``."""
        if start == end:
            self.tree[node] = data[start]
            return
        mid = (start + end) // 2
        left, right = 2 * node + 1, 2 * node + 2
        self._build(data, left, start, mid)
        self._build(data, right, mid + 1, end)
        self.tree[node] = self.tree[left] + self.tree[right]

    def query(self, L: int, R: int) -> int:
        """
        Return the sum over the inclusive range [L, R].

        Positions outside ``[0, n - 1]`` contribute 0, so a range that
        overlaps nothing (or any range on an empty tree) returns 0.
        """
        if self.n == 0:
            return 0
        return self._query(0, 0, self.n - 1, L, R)

    def _query(self, node, start, end, L, R):
        """Sum of the part of ``[start, end]`` that lies inside [L, R]."""
        if R < start or end < L:
            return 0
        if L <= start and end <= R:
            return self.tree[node]
        mid = (start + end) // 2
        return (self._query(2 * node + 1, start, mid, L, R) +
                self._query(2 * node + 2, mid + 1, end, L, R))

    def update(self, idx: int, val: int):
        """
        Set the element at position ``idx`` to ``val``.

        Raises:
            IndexError: if ``idx`` is not in ``[0, n - 1]``. Without this
                check the descent would end on the first or last leaf and
                overwrite an unrelated element.
        """
        if not 0 <= idx < self.n:
            raise IndexError(
                f"index {idx} out of range for segment tree of size {self.n}"
            )
        self._update(0, 0, self.n - 1, idx, val)

    def _update(self, node, start, end, idx, val):
        """Write ``val`` into the leaf for ``idx`` and refresh its ancestors."""
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        left, right = 2 * node + 1, 2 * node + 2
        if idx <= mid:
            self._update(left, start, mid, idx, val)
        else:
            self._update(right, mid + 1, end, idx, val)
        self.tree[node] = self.tree[left] + self.tree[right]
class MinMaxSegmentTree:
    """
    Segment Tree tracking both the minimum and the maximum of each range.

    Two parallel flat lists (``min_tree`` and ``max_tree``) share the same
    heap layout: the root is node 0 and the children of ``node`` are
    ``2 * node + 1`` and ``2 * node + 2``.
    """

    def __init__(self, data: List[int]):
        """Build both trees over ``data`` (which may be empty)."""
        self.n = len(data)
        self.min_tree = [float('inf')] * (4 * self.n)
        self.max_tree = [float('-inf')] * (4 * self.n)
        # Recursing on the empty segment [0, -1] would never terminate.
        if self.n > 0:
            self._build(data, 0, 0, self.n - 1)

    def _pull(self, node):
        """Recompute ``node`` in both trees from its two children."""
        left, right = 2 * node + 1, 2 * node + 2
        self.min_tree[node] = min(self.min_tree[left], self.min_tree[right])
        self.max_tree[node] = max(self.max_tree[left], self.max_tree[right])

    def _build(self, data, node, start, end):
        """Fill the subtree rooted at ``node`` covering ``[start, end]``."""
        if start == end:
            self.min_tree[node] = data[start]
            self.max_tree[node] = data[start]
            return
        mid = (start + end) // 2
        self._build(data, 2 * node + 1, start, mid)
        self._build(data, 2 * node + 2, mid + 1, end)
        self._pull(node)

    def query_min(self, L: int, R: int) -> int:
        """
        Return the minimum over the inclusive range [L, R].

        A range that overlaps no element (or any range on an empty tree)
        returns ``float('inf')``.
        """
        if self.n == 0:
            return float('inf')
        return self._query_min(0, 0, self.n - 1, L, R)

    def query_max(self, L: int, R: int) -> int:
        """
        Return the maximum over the inclusive range [L, R].

        A range that overlaps no element (or any range on an empty tree)
        returns ``float('-inf')``.
        """
        if self.n == 0:
            return float('-inf')
        return self._query_max(0, 0, self.n - 1, L, R)

    def _query_min(self, node, start, end, L, R):
        """Minimum of the part of ``[start, end]`` inside [L, R]."""
        if R < start or end < L:
            return float('inf')
        if L <= start and end <= R:
            return self.min_tree[node]
        mid = (start + end) // 2
        return min(self._query_min(2 * node + 1, start, mid, L, R),
                   self._query_min(2 * node + 2, mid + 1, end, L, R))

    def _query_max(self, node, start, end, L, R):
        """Maximum of the part of ``[start, end]`` inside [L, R]."""
        if R < start or end < L:
            return float('-inf')
        if L <= start and end <= R:
            return self.max_tree[node]
        mid = (start + end) // 2
        return max(self._query_max(2 * node + 1, start, mid, L, R),
                   self._query_max(2 * node + 2, mid + 1, end, L, R))

    def update(self, idx: int, val: int):
        """
        Set the element at position ``idx`` to ``val`` in both trees.

        Raises:
            IndexError: if ``idx`` is not in ``[0, n - 1]``. Without this
                check the descent would end on the first or last leaf and
                overwrite an unrelated element.
        """
        if not 0 <= idx < self.n:
            raise IndexError(
                f"index {idx} out of range for segment tree of size {self.n}"
            )
        self._update(0, 0, self.n - 1, idx, val)

    def _update(self, node, start, end, idx, val):
        """Write ``val`` into the leaf for ``idx`` and refresh its ancestors."""
        if start == end:
            self.min_tree[node] = val
            self.max_tree[node] = val
            return
        mid = (start + end) // 2
        if idx <= mid:
            self._update(2 * node + 1, start, mid, idx, val)
        else:
            self._update(2 * node + 2, mid + 1, end, idx, val)
        self._pull(node)
# ============================================
# REAL-WORLD: STOCK MARKET ANALYSIS
# ============================================
@dataclass
class StockData:
    """Plain record pairing a timestamp with a price."""

    # NOTE(review): the unit/epoch of the timestamp is not defined in this
    # file — confirm with callers before relying on it.
    timestamp: int
    price: float
class StockAnalyzer:
    """
    Analyze stock prices using a MinMaxSegmentTree.

    Prices are stored internally as integer cents so that min/max
    comparisons are exact; results are converted back by dividing by 100.

    Supports:
    - Min/Max price in time range
    - Price updates
    """

    @staticmethod
    def _to_cents(price: float) -> int:
        """
        Convert a price to integer cents.

        Rounds instead of truncating: ``1.15 * 100`` evaluates to
        ``114.99999999999999`` in binary floating point, which ``int()``
        would turn into 114.
        """
        return round(price * 100)

    def __init__(self, prices: List[float]):
        """Store ``prices`` as cents and build the min/max tree over them."""
        self.prices = [self._to_cents(p) for p in prices]
        self.tree = MinMaxSegmentTree(self.prices)

    def get_min_price(self, start_time: int, end_time: int) -> float:
        """Get minimum price in the inclusive range [start_time, end_time]."""
        return self.tree.query_min(start_time, end_time) / 100

    def get_max_price(self, start_time: int, end_time: int) -> float:
        """Get maximum price in the inclusive range [start_time, end_time]."""
        return self.tree.query_max(start_time, end_time) / 100

    def update_price(self, time: int, new_price: float):
        """
        Update price at specific time.

        Raises:
            IndexError: if ``time`` is not a valid position in the price
                list.
        """
        if not 0 <= time < len(self.prices):
            raise IndexError(
                f"time {time} out of range for {len(self.prices)} prices"
            )
        cents = self._to_cents(new_price)
        self.tree.update(time, cents)
        # Keep the stored cent list in step with the tree.
        self.prices[time] = cents

    def get_volatility(self, start_time: int, end_time: int) -> float:
        """Get price volatility (max - min) in range."""
        # Subtract in cents first so the difference itself is exact; only
        # the final division introduces floating point.
        spread_cents = (self.tree.query_max(start_time, end_time)
                        - self.tree.query_min(start_time, end_time))
        return spread_cents / 100
# ============================================
# USAGE EXAMPLE
# ============================================
def _demo_sum_tree():
    """Example 1: Basic Sum Segment Tree."""
    print("=== Sum Segment Tree ===")
    arr = [1, 3, 5, 7, 9, 11]
    st = SumSegmentTree(arr)
    print(f"Array: {arr}")
    print(f"Sum[1, 3] = {st.query(1, 3)}")  # 3 + 5 + 7 = 15
    print(f"Sum[0, 5] = {st.query(0, 5)}")  # 36
    st.update(2, 10)
    print(f"After update(2, 10):")
    print(f"Sum[1, 3] = {st.query(1, 3)}")  # 3 + 10 + 7 = 20


def _demo_gcd_tree():
    """Example 2: Generic Segment Tree with gcd as the combine function."""
    from math import gcd

    print("\n=== Generic Segment Tree (GCD) ===")
    gcd_tree = SegmentTree([12, 18, 24, 30, 36], gcd, 0)
    print(f"GCD[0, 4] = {gcd_tree.query(0, 4)}")  # GCD of all = 6
    print(f"GCD[1, 3] = {gcd_tree.query(1, 3)}")  # GCD(18,24,30) = 6


def _demo_stock_analysis():
    """Example 3: Stock Analysis."""
    print("\n=== Stock Market Analysis ===")
    prices = [100.0, 102.5, 98.0, 105.0, 110.0, 108.0, 115.0]
    analyzer = StockAnalyzer(prices)
    print(f"Prices: {prices}")
    print(f"Min price [1, 4]: ${analyzer.get_min_price(1, 4)}")
    print(f"Max price [1, 4]: ${analyzer.get_max_price(1, 4)}")
    print(f"Volatility [1, 4]: ${analyzer.get_volatility(1, 4)}")
    analyzer.update_price(3, 120.0)
    print(f"\nAfter price update at time 3 → $120:")
    print(f"Max price [1, 4]: ${analyzer.get_max_price(1, 4)}")


if __name__ == "__main__":
    # Run the three examples in the same order as before.
    _demo_sum_tree()
    _demo_gcd_tree()
    _demo_stock_analysis()


# Complexity Analysis
| Operation | Time | Space |
|---|---|---|
| Build | O(N) | O(N) |
| Query | O(log N) | O(log N) stack |
| Update | O(log N) | O(log N) stack |
Lazy Propagation (Advanced)
⚠️ When Updates are on Ranges
Nếu cần range update (update cả đoạn [L, R]), dùng point update cho từng phần tử sẽ tốn O(log N) × range_size, tức O(N log N) trong trường hợp xấu nhất.
Solution: Lazy Propagation - Defer updates until needed.
python
def range_update_lazy(node, start, end, L, R, val):
    """
    Add ``val`` to all elements in the inclusive range [L, R].

    Works on the module-level ``tree`` and ``lazy`` lists and calls
    ``push_down``, all of which are defined outside this snippet.
    """
    # push_down runs first on every visited node, before any overlap test.
    # NOTE(review): presumably it flushes lazy[node] to the children —
    # confirm against its definition.
    push_down(node, start, end)

    no_overlap = end < L or R < start
    if no_overlap:
        return

    fully_covered = L <= start and end <= R
    if fully_covered:
        # Record the pending addition and apply it to this node's sum:
        # each of the (end - start + 1) covered elements grows by val.
        span = end - start + 1
        lazy[node] += val
        tree[node] += val * span
        return

    half = (start + end) // 2
    lhs, rhs = 2 * node + 1, 2 * node + 2
    range_update_lazy(lhs, start, half, L, R, val)
    range_update_lazy(rhs, half + 1, end, L, R, val)
    # Recompute this node from its children after the partial update.
    tree[node] = tree[lhs] + tree[rhs]


# Common Operations
| Operation | Combine Function | Identity |
|---|---|---|
| Sum | a + b | 0 |
| Min | min(a, b) | +∞ |
| Max | max(a, b) | -∞ |
| GCD | gcd(a, b) | 0 |
| Product | a * b | 1 |
| XOR | a ^ b | 0 |
💡 HPN's Rule
"Segment Tree = Divide & Conquer trên array. Mỗi node lưu aggregate của range. Query/Update đi từ root xuống leaves theo O(log N) path."