📈 Scaling Training
Level: Advanced
Solves: Train large models efficiently with distributed strategies and cost optimization
When to Scale
💡 Professor Tom
Scaling is not always the answer. Before throwing hardware at the problem, make sure you have already optimized single-GPU training. A well-tuned model on 1 GPU often beats a poorly-tuned model on 8 GPUs.
┌─────────────────────────────────────────────────────────────────┐
│ SCALING DECISION FRAMEWORK │
├─────────────────────────────────────────────────────────────────┤
│ │
│ SCALE WHEN: │
│ ─────────── │
│ • Model doesn't fit in single GPU memory │
│ • Training time > acceptable threshold (days/weeks) │
│ • Need to explore large hyperparameter space │
│ • Production requires fast iteration cycles │
│ │
│ DON'T SCALE WHEN: │
│ ───────────────── │
│ • Haven't optimized single-GPU training │
│ • Data is the bottleneck, not compute │
│ • Model is already converging well │
│ • Budget doesn't justify speedup │
│ │
└─────────────────────────────────────────────────────────────────┘
Hardware Selection
GPU Comparison
| GPU | Memory | FP32 TFLOPS | FP16 TFLOPS | Best For |
|---|---|---|---|---|
| RTX 3090 | 24GB | 35.6 | 71 | Research, small models |
| RTX 4090 | 24GB | 82.6 | 165 | Research, inference |
| A100 40GB | 40GB | 19.5 | 312 | Training, production |
| A100 80GB | 80GB | 19.5 | 312 | Large models |
| H100 | 80GB | 67 | 1979 | LLM training |
Memory Estimation
┌─────────────────────────────────────────────────────────────────┐
│ GPU MEMORY BREAKDOWN │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Total Memory = Model + Optimizer + Activations + Batch │
│ │
│ MODEL MEMORY │
│ ──────────── │
│ • Parameters: N params × 4 bytes (FP32) │
│ • Gradients: N params × 4 bytes │
│ • Total: ~8 bytes per parameter (FP32) │
│ │
│ OPTIMIZER MEMORY (Adam) │
│ ────────────────────── │
│ • Momentum: N params × 4 bytes │
│ • Variance: N params × 4 bytes │
│ • Total: ~8 bytes per parameter │
│ │
│ ACTIVATIONS │
│ ─────────── │
│ • Depends on batch size and model depth │
│ • Often largest memory consumer │
│ • Can be reduced with gradient checkpointing │
│ │
│ EXAMPLE: 1B parameter model │
│ ───────────────────────────── │
│ Model: 1B × 8 bytes = 8GB │
│ Optimizer: 1B × 8 bytes = 8GB │
│ Activations: ~8-16GB (varies) │
│ Total: ~24-32GB minimum │
│ │
└─────────────────────────────────────────────────────────────────┘
Single-GPU Optimization
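The memory breakdown above can be turned into a quick back-of-the-envelope estimator. This is a rough sketch: the per-parameter byte counts and the activation range come straight from the box, and the function name is illustrative, not a standard API.

```python
def estimate_memory_gb(n_params, activation_low_gb=8, activation_high_gb=16):
    """Rough FP32 training memory estimate (low, high) in GB.

    Assumes 8 bytes/param for the model (params + gradients) and
    8 bytes/param for Adam state (momentum + variance). The activation
    range is a placeholder; measure it for your actual model.
    """
    model_gb = n_params * 8 / 1e9       # parameters + gradients
    optimizer_gb = n_params * 8 / 1e9   # Adam momentum + variance
    low = model_gb + optimizer_gb + activation_low_gb
    high = model_gb + optimizer_gb + activation_high_gb
    return low, high

# 1B-parameter example from the box: ~24-32GB minimum
print(estimate_memory_gb(1_000_000_000))
```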
Memory Optimization Techniques
| Technique | Memory Savings | Speed Impact | Complexity |
|---|---|---|---|
| Mixed Precision (FP16) | ~50% | +20-50% faster | Low |
| Gradient Checkpointing | ~60-70% | -20-30% slower | Low |
| Gradient Accumulation | Enables larger effective batch | Minimal | Low |
| Activation Offloading | ~30-50% | -10-20% slower | Medium |
| Model Parallelism | Enables larger models | Varies | High |
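Gradient accumulation from the table deserves a concrete sketch: it simulates a larger effective batch by summing gradients over several micro-batches before each optimizer step. The toy model and sizes below are arbitrary assumptions chosen to make the pattern runnable.

```python
import torch
import torch.nn as nn

# Toy setup to illustrate the pattern (sizes are arbitrary assumptions)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(8)]

accumulation_steps = 4  # effective batch = 8 × 4 = 32
num_steps = 0

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(data):
    # Divide the loss so accumulated gradients match one large averaged batch
    loss = criterion(model(inputs), targets) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # update once per effective batch
        optimizer.zero_grad()
        num_steps += 1
```

With 8 micro-batches and `accumulation_steps=4`, the optimizer steps twice, each time on gradients equivalent to a batch of 32.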
Mixed Precision Training
```python
# PyTorch Automatic Mixed Precision
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()

    # Forward pass in FP16
    with autocast():
        output = model(inputs)
        loss = criterion(output, targets)

    # Backward pass with loss scaling to avoid FP16 gradient underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
Gradient Checkpointing
┌─────────────────────────────────────────────────────────────────┐
│ GRADIENT CHECKPOINTING │
├─────────────────────────────────────────────────────────────────┤
│ │
│ NORMAL TRAINING │
│ ─────────────── │
│ Forward: Save all activations │
│ Backward: Use saved activations │
│ Memory: O(n) where n = number of layers │
│ │
│ WITH CHECKPOINTING │
│ ────────────────── │
│ Forward: Save only checkpoint activations │
│ Backward: Recompute activations from checkpoints │
│ Memory: O(√n) with √n checkpoints │
│ │
│ Trade-off: ~30% more compute for ~70% less memory │
│ │
└─────────────────────────────────────────────────────────────────┘

```python
# PyTorch gradient checkpointing
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedModel(nn.Module):
    # layer1, layer2, head are assumed to be defined in __init__
    def forward(self, x):
        # Checkpoint expensive layers: their activations are recomputed
        # during backward instead of being stored in the forward pass
        x = checkpoint(self.layer1, x)
        x = checkpoint(self.layer2, x)
        return self.head(x)
```
Distributed Training
Parallelism Strategies
┌─────────────────────────────────────────────────────────────────┐
│ PARALLELISM STRATEGIES │
├─────────────────────────────────────────────────────────────────┤
│ │
│ DATA PARALLELISM (DP) │
│ ───────────────────── │
│ • Same model on each GPU │
│ • Different data batches │
│ • Gradients synchronized │
│ • Best for: Models that fit in single GPU │
│ │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │GPU 0│ │GPU 1│ │GPU 2│ │GPU 3│ │
│ │Model│ │Model│ │Model│ │Model│ │
│ │Batch│ │Batch│ │Batch│ │Batch│ │
│ │ 0-7 │ │ 8-15│ │16-23│ │24-31│ │
│ └──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘ │
│ └────────┴────────┴────────┘ │
│ AllReduce Gradients │
│ │
│ MODEL PARALLELISM (MP) │
│ ────────────────────── │
│ • Model split across GPUs │
│ • Same data on each GPU │
│ • Best for: Models too large for single GPU │
│ │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │GPU 0│→ │GPU 1│→ │GPU 2│→ │GPU 3│ │
│ │Layer│ │Layer│ │Layer│ │Layer│ │
│ │ 1-4 │ │ 5-8 │ │9-12 │ │13-16│ │
│ └─────┘ └─────┘ └─────┘ └─────┘ │
│ │
│ PIPELINE PARALLELISM │
│ ──────────────────── │
│ • Model split + micro-batches │
│ • Reduces bubble overhead │
│ • Best for: Very large models │
│ │
│ TENSOR PARALLELISM │
│ ───────────────── │
│ • Single layer split across GPUs │
│ • High communication overhead │
│ • Best for: Very wide layers (LLMs) │
│ │
└─────────────────────────────────────────────────────────────────┘
PyTorch DDP (Distributed Data Parallel)
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train(rank, world_size):
    setup(rank, world_size)

    model = MyModel().to(rank)
    model = DDP(model, device_ids=[rank])

    # Use DistributedSampler so each rank sees a distinct data shard
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, sampler=sampler)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # Important: reshuffles across epochs
        for batch in dataloader:
            # Training loop as usual
            ...

# Launch with torchrun
# torchrun --nproc_per_node=4 train.py
```
FSDP (Fully Sharded Data Parallel)
┌─────────────────────────────────────────────────────────────────┐
│ FSDP vs DDP │
├─────────────────────────────────────────────────────────────────┤
│ │
│ DDP (Distributed Data Parallel) │
│ ─────────────────────────────── │
│ • Full model replica on each GPU │
│ • Memory: O(model_size) per GPU │
│ • Communication: AllReduce gradients │
│ │
│ FSDP (Fully Sharded Data Parallel) │
│ ────────────────────────────────── │
│ • Model sharded across GPUs │
│ • Memory: O(model_size / num_gpus) per GPU │
│ • Communication: AllGather params, ReduceScatter grads │
│ • Best for: Models that don't fit in single GPU │
│ │
│ WHEN TO USE WHAT │
│ ──────────────── │
│ Model fits in 1 GPU → DDP │
│ Model doesn't fit → FSDP or DeepSpeed ZeRO │
│ Very large model → FSDP + Tensor Parallelism │
│ │
└─────────────────────────────────────────────────────────────────┘

```python
# PyTorch FSDP
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# auto_wrap_policy must be a callable; bind your model's transformer
# block class (TransformerBlock is a placeholder for your own layer class)
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

# Wrap model with FSDP: params, gradients, and optimizer state are sharded
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
)
```
Cost Control
Cloud GPU Pricing Strategy
┌─────────────────────────────────────────────────────────────────┐
│ COST OPTIMIZATION STRATEGIES │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 1. SPOT/PREEMPTIBLE INSTANCES │
│ ───────────────────────── │
│ • 60-90% cheaper than on-demand │
│ • Can be interrupted │
│ • Use checkpointing for fault tolerance │
│ │
│ 2. RIGHT-SIZING │
│ ──────────── │
│ • Don't use A100 for small models │
│ • Profile memory usage first │
│ • Scale horizontally if cheaper │
│ │
│ 3. EFFICIENT TRAINING │
│ ─────────────────── │
│ • Mixed precision (2x speedup = 50% cost) │
│ • Gradient accumulation (use smaller instances) │
│ • Early stopping (don't overtrain) │
│ │
│ 4. SCHEDULING │
│ ────────── │
│ • Train during off-peak hours │
│ • Use reserved instances for predictable workloads │
│ • Batch experiments together │
│ │
└─────────────────────────────────────────────────────────────────┘
Checkpointing for Fault Tolerance
```python
# Robust checkpointing for spot instances
import os
import torch

def save_checkpoint(model, optimizer, epoch, loss, path):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    # Save to a temp file first, then rename (atomic on POSIX filesystems)
    temp_path = path + '.tmp'
    torch.save(checkpoint, temp_path)
    os.rename(temp_path, path)

def load_checkpoint(model, optimizer, path):
    if os.path.exists(path):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['loss']
    return 0, float('inf')
```
Cost Estimation Formula
┌─────────────────────────────────────────────────────────────────┐
│ TRAINING COST ESTIMATION │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Total Cost = GPU Hours × Price per Hour × Overhead Factor │
│ │
│ GPU Hours = (Dataset Size × Epochs) │
│ ÷ (Samples per Second × 3600) │
│ │
│ Overhead Factor (typically 1.2-1.5): │
│ • Failed experiments │
│ • Hyperparameter tuning │
│ • Data preprocessing │
│ • Debugging time │
│ │
│ EXAMPLE │
│ ─────── │
│ Dataset: 1M samples │
│ Epochs: 10 │
│ Batch size: 32 │
│ Throughput: 100 samples/sec │
│ GPU: A100 @ $3/hour │
│ │
│ GPU Hours = (1M × 10) ÷ (100 × 3600) = 27.8 hours │
│ Cost = 27.8 × $3 × 1.3 = ~$108 │
│ │
└─────────────────────────────────────────────────────────────────┘
Scaling Best Practices
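The worked example above can be reproduced with a small helper. This is a sketch: the default 1.3 overhead factor is simply the midpoint of the suggested 1.2-1.5 range, and the function name is illustrative.

```python
def training_cost(dataset_size, epochs, samples_per_sec,
                  price_per_hour, overhead=1.3):
    """Return (gpu_hours, total_cost) from the estimation formula."""
    gpu_hours = (dataset_size * epochs) / (samples_per_sec * 3600)
    return gpu_hours, gpu_hours * price_per_hour * overhead

# Example from the box: 1M samples, 10 epochs, 100 samples/sec, A100 @ $3/hr
hours, cost = training_cost(1_000_000, 10, 100, 3.0)
print(f"{hours:.1f} GPU hours, ~${cost:.0f}")
```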
Scaling Checklist
┌─────────────────────────────────────────────────────────────────┐
│ SCALING CHECKLIST │
├─────────────────────────────────────────────────────────────────┤
│ │
│ BEFORE SCALING │
│ ────────────── │
│ □ Single-GPU training optimized │
│ □ Mixed precision enabled │
│ □ Data loading not bottleneck │
│ □ Model architecture finalized │
│ □ Hyperparameters roughly tuned │
│ │
│ DURING SCALING │
│ ────────────── │
│ □ Learning rate scaled with batch size │
│ □ Warmup steps adjusted │
│ □ Gradient accumulation configured │
│ □ Checkpointing enabled │
│ □ Monitoring in place │
│ │
│ AFTER SCALING │
│ ───────────── │
│ □ Verify convergence matches single-GPU │
│ □ Check for numerical differences │
│ □ Monitor GPU utilization │
│ □ Track cost vs speedup │
│ │
└─────────────────────────────────────────────────────────────────┘
Learning Rate Scaling
⚠️ Critical: Learning Rate Adjustment
When you increase the batch size, you also need to adjust the learning rate. Rule of thumb: linear scaling with warmup.
```python
# Linear scaling rule
base_lr = 0.001
base_batch_size = 32
actual_batch_size = 256  # 8 GPUs × 32

scaled_lr = base_lr * (actual_batch_size / base_batch_size)
# scaled_lr = 0.008

# With warmup (recommended)
warmup_steps = 1000

def get_lr(step):
    if step < warmup_steps:
        return scaled_lr * (step / warmup_steps)
    return scaled_lr
```
Monitoring at Scale
Key Metrics to Track
| Metric | Target | Red Flag |
|---|---|---|
| GPU Utilization | >80% | <50% |
| Memory Usage | 80-95% | >98% or <50% |
| Throughput | Linear scaling | Sub-linear |
| Communication Time | <20% of step | >40% |
| Loss Convergence | Match single-GPU | Divergence |
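The "Throughput: linear scaling" target in the table can be quantified as scaling efficiency: measured multi-GPU throughput divided by the ideal N × single-GPU throughput. The throughput numbers below are hypothetical.

```python
def scaling_efficiency(single_gpu_throughput, multi_gpu_throughput, num_gpus):
    """Fraction of ideal linear speedup achieved (1.0 = perfect scaling)."""
    ideal = single_gpu_throughput * num_gpus
    return multi_gpu_throughput / ideal

# Hypothetical: 100 samples/sec on 1 GPU, 680 samples/sec on 8 GPUs
eff = scaling_efficiency(100, 680, 8)
print(f"Scaling efficiency: {eff:.0%}")
```

A value well below ~0.8 is the "sub-linear" red flag from the table, and usually points at communication or data-loading overhead.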
Distributed Training Debugging
```python
# Debug distributed training
import torch
import torch.distributed as dist

def debug_distributed(model):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    print(f"[Rank {rank}/{world_size}] GPU: {torch.cuda.current_device()}")
    print(f"[Rank {rank}] Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    # Check gradient sync: after a DDP backward pass, gradients should be
    # identical on every rank, so each all-reduced sum should equal
    # world_size × the per-rank sum
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_sum = param.grad.sum()
            dist.all_reduce(grad_sum)
            if rank == 0:
                print(f"{name}: grad_sum = {grad_sum.item()}")
```
✅ Operational Checklist
Pre-Scaling Gate
| Check | Required | Owner | Notes |
|---|---|---|---|
| Single-GPU training optimized | ✓ | ML Engineer | Mixed precision, data loading |
| Memory profiled | ✓ | ML Engineer | Know actual requirements |
| Baseline metrics established | ✓ | ML Engineer | Compare after scaling |
| Cost estimate approved | ✓ | ML Lead + Finance | Budget allocated |
| Checkpointing configured | ✓ | ML Engineer | Fault tolerance |
Scaling Configuration Gate
| Check | Required | Owner |
|---|---|---|
| Learning rate scaled correctly | ✓ | ML Engineer |
| Warmup steps adjusted | ✓ | ML Engineer |
| Gradient accumulation configured | ✓ | ML Engineer |
| DistributedSampler used | ✓ | ML Engineer |
| Monitoring dashboards ready | ✓ | ML Engineer |
Cost Control Gate
| Check | Required | Notes |
|---|---|---|
| Spot/preemptible instances evaluated | ✓ | 60-90% savings |
| Right GPU tier selected | ✓ | Don't over-provision |
| Auto-shutdown configured | ✓ | Idle instance termination |
| Cost alerts set up | ✓ | Budget threshold warnings |
| Training duration estimated | ✓ | Expected vs actual tracking |
📎 Cross-References
- 📎 Training Foundations - Optimization basics
- 📎 Debugging Training - Troubleshooting distributed issues
- 📎 ML Deployment - Production deployment patterns
- 📎 ML Governance - Cost governance
- 📎 AWS Compute - Cloud GPU options
- 📎 GCP Compute - GCP GPU instances