
📈 Scaling Training

Level: Advanced. Solves: Train large models efficiently with distributed strategies and cost optimization.

When to Scale

💡 Professor Tom

Scaling is not always the answer. Before throwing hardware at the problem, make sure you have optimized single-GPU training first. A well-tuned model on 1 GPU often beats a poorly tuned model on 8 GPUs.

┌─────────────────────────────────────────────────────────────────┐
│              SCALING DECISION FRAMEWORK                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  SCALE WHEN:                                                    │
│  ───────────                                                    │
│  • Model doesn't fit in single GPU memory                       │
│  • Training time > acceptable threshold (days/weeks)            │
│  • Need to explore large hyperparameter space                   │
│  • Production requires fast iteration cycles                    │
│                                                                 │
│  DON'T SCALE WHEN:                                              │
│  ─────────────────                                              │
│  • Haven't optimized single-GPU training                        │
│  • Data is the bottleneck, not compute                          │
│  • Model is already converging well                             │
│  • Budget doesn't justify speedup                               │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Hardware Selection

GPU Comparison

| GPU | Memory | FP32 TFLOPS | FP16 TFLOPS | Best For |
|---|---|---|---|---|
| RTX 3090 | 24GB | 35.6 | 71 | Research, small models |
| RTX 4090 | 24GB | 82.6 | 165 | Research, inference |
| A100 40GB | 40GB | 19.5 | 312 | Training, production |
| A100 80GB | 80GB | 19.5 | 312 | Large models |
| H100 | 80GB | 67 | 1979 | LLM training |

Memory Estimation

┌─────────────────────────────────────────────────────────────────┐
│              GPU MEMORY BREAKDOWN                               │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Total Memory = Model + Optimizer + Activations + Batch         │
│                                                                 │
│  MODEL MEMORY                                                   │
│  ────────────                                                   │
│  • Parameters: N params × 4 bytes (FP32)                        │
│  • Gradients: N params × 4 bytes                                │
│  • Total: ~8 bytes per parameter (FP32)                         │
│                                                                 │
│  OPTIMIZER MEMORY (Adam)                                        │
│  ──────────────────────                                         │
│  • Momentum: N params × 4 bytes                                 │
│  • Variance: N params × 4 bytes                                 │
│  • Total: ~8 bytes per parameter                                │
│                                                                 │
│  ACTIVATIONS                                                    │
│  ───────────                                                    │
│  • Depends on batch size and model depth                        │
│  • Often largest memory consumer                                │
│  • Can be reduced with gradient checkpointing                   │
│                                                                 │
│  EXAMPLE: 1B parameter model                                    │
│  ─────────────────────────────                                  │
│  Model: 1B × 8 bytes = 8GB                                      │
│  Optimizer: 1B × 8 bytes = 8GB                                  │
│  Activations: ~8-16GB (varies)                                  │
│  Total: ~24-32GB minimum                                        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
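The breakdown above translates into a quick back-of-envelope estimator. A minimal sketch; the activation range is a placeholder you should replace with actual profiling of your model:

```python
def estimate_training_memory_gb(n_params, bytes_per_param=4,
                                activation_gb=(8.0, 16.0)):
    """FP32 training with Adam: params + grads + two optimizer states,
    plus a rough activation range."""
    params = n_params * bytes_per_param / 1e9      # model parameters
    grads = n_params * bytes_per_param / 1e9       # gradients
    optim = n_params * 2 * bytes_per_param / 1e9   # Adam momentum + variance
    return (params + grads + optim + activation_gb[0],
            params + grads + optim + activation_gb[1])

low, high = estimate_training_memory_gb(1_000_000_000)
print(f"1B params: {low:.0f}-{high:.0f} GB")  # 24-32 GB, matching the example
```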

Single-GPU Optimization

Memory Optimization Techniques

| Technique | Memory Savings | Speed Impact | Complexity |
|---|---|---|---|
| Mixed Precision (FP16) | ~50% | +20-50% faster | Low |
| Gradient Checkpointing | ~60-70% | -20-30% slower | Low |
| Gradient Accumulation | Enables larger effective batch | Minimal | Low |
| Activation Offloading | ~30-50% | -10-20% slower | Medium |
| Model Parallelism | Enables larger models | Varies | High |

Mixed Precision Training

```python
# PyTorch Automatic Mixed Precision
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()

    # Forward pass runs eligible ops in FP16
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Backward pass with loss scaling to avoid FP16 gradient underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
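Gradient accumulation (listed in the techniques table) needs no special API: divide the loss by the accumulation count, call `backward()` on every micro-batch so gradients add up in `.grad`, and step the optimizer only at the boundary. A minimal CPU-runnable sketch with a toy model; in the AMP loop above you would additionally wrap the forward in `autocast()` and use `scaler.scale(loss).backward()`:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accum_steps = 4  # effective batch = micro-batch size × 4
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(8)]

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(data):
    # Scale the loss so accumulated gradients average over the micro-batches
    loss = criterion(model(inputs), targets) / accum_steps
    loss.backward()          # gradients accumulate in param.grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()     # one update per effective batch
        optimizer.zero_grad()
```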

Gradient Checkpointing

┌─────────────────────────────────────────────────────────────────┐
│              GRADIENT CHECKPOINTING                             │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  NORMAL TRAINING                                                │
│  ───────────────                                                │
│  Forward: Save all activations                                  │
│  Backward: Use saved activations                                │
│  Memory: O(n) where n = number of layers                        │
│                                                                 │
│  WITH CHECKPOINTING                                             │
│  ──────────────────                                             │
│  Forward: Save only checkpoint activations                      │
│  Backward: Recompute activations from checkpoints               │
│  Memory: O(√n) with √n checkpoints                              │
│                                                                 │
│  Trade-off: ~30% more compute for ~70% less memory              │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```python
# PyTorch gradient checkpointing
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedModel(nn.Module):
    def forward(self, x):
        # Checkpoint expensive layers: only their inputs are stored;
        # activations are recomputed during the backward pass
        x = checkpoint(self.layer1, x, use_reentrant=False)
        x = checkpoint(self.layer2, x, use_reentrant=False)
        return self.head(x)
```
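For plain `nn.Sequential` stacks there is also `torch.utils.checkpoint.checkpoint_sequential`, which splits the stack into segments and stores activations only at segment boundaries (4 segments for 16 layers, in line with the √n analysis above). A minimal CPU-runnable sketch:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(16)])
x = torch.randn(8, 64, requires_grad=True)

# Store activations only at 4 segment boundaries; the layers inside each
# segment are recomputed during backward
out = checkpoint_sequential(model, 4, x)
out.sum().backward()
```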

Distributed Training

Parallelism Strategies

┌─────────────────────────────────────────────────────────────────┐
│              PARALLELISM STRATEGIES                             │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  DATA PARALLELISM (DP)                                          │
│  ─────────────────────                                          │
│  • Same model on each GPU                                       │
│  • Different data batches                                       │
│  • Gradients synchronized                                       │
│  • Best for: Models that fit in single GPU                      │
│                                                                 │
│  ┌─────┐  ┌─────┐  ┌─────┐  ┌─────┐                             │
│  │GPU 0│  │GPU 1│  │GPU 2│  │GPU 3│                             │
│  │Model│  │Model│  │Model│  │Model│                             │
│  │Batch│  │Batch│  │Batch│  │Batch│                             │
│  │ 0-7 │  │ 8-15│  │16-23│  │24-31│                             │
│  └──┬──┘  └──┬──┘  └──┬──┘  └──┬──┘                             │
│     └────────┴────────┴────────┘                                │
│              AllReduce Gradients                                │
│                                                                 │
│  MODEL PARALLELISM (MP)                                         │
│  ──────────────────────                                         │
│  • Model split across GPUs                                      │
│  • Same data on each GPU                                        │
│  • Best for: Models too large for single GPU                    │
│                                                                 │
│  ┌─────┐  ┌─────┐  ┌─────┐  ┌─────┐                             │
│  │GPU 0│→ │GPU 1│→ │GPU 2│→ │GPU 3│                             │
│  │Layer│  │Layer│  │Layer│  │Layer│                             │
│  │ 1-4 │  │ 5-8 │  │9-12 │  │13-16│                             │
│  └─────┘  └─────┘  └─────┘  └─────┘                             │
│                                                                 │
│  PIPELINE PARALLELISM                                           │
│  ────────────────────                                           │
│  • Model split + micro-batches                                  │
│  • Reduces bubble overhead                                      │
│  • Best for: Very large models                                  │
│                                                                 │
│  TENSOR PARALLELISM                                             │
│  ─────────────────                                              │
│  • Single layer split across GPUs                               │
│  • High communication overhead                                  │
│  • Best for: Very wide layers (LLMs)                            │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
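The model-parallel layout above can be sketched in a few lines: place each half of the network on its own device and move activations across at the boundary. This toy version falls back to CPU when two GPUs are not visible, so the device selection is illustrative:

```python
import torch
import torch.nn as nn

# Fall back to CPU so the sketch runs anywhere; with 2+ GPUs the two
# halves really live on different devices
dev0 = torch.device("cuda:0" if torch.cuda.device_count() >= 2 else "cpu")
dev1 = torch.device("cuda:1" if torch.cuda.device_count() >= 2 else "cpu")

class TwoDeviceModel(nn.Module):
    """Naive model parallelism: first half on dev0, second half on dev1."""
    def __init__(self):
        super().__init__()
        self.part1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU()).to(dev0)
        self.part2 = nn.Sequential(nn.Linear(512, 512), nn.ReLU()).to(dev1)

    def forward(self, x):
        x = self.part1(x.to(dev0))
        return self.part2(x.to(dev1))  # activations cross the device link here

model = TwoDeviceModel()
out = model(torch.randn(4, 512))
```

Note the bubble this creates: while `part2` runs, `dev0` sits idle. That idle time is exactly what pipeline parallelism's micro-batching reduces.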

PyTorch DDP (Distributed Data Parallel)

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train(rank, world_size):
    setup(rank, world_size)

    model = MyModel().to(rank)
    model = DDP(model, device_ids=[rank])

    # DistributedSampler gives each rank a disjoint shard of the data
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, sampler=sampler)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # important: reshuffles shards each epoch
        for batch in dataloader:
            # Training loop as usual
            ...

# Launch with torchrun, which sets RANK and WORLD_SIZE in the environment:
#   torchrun --nproc_per_node=4 train.py
if __name__ == "__main__":
    train(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
```

FSDP (Fully Sharded Data Parallel)

┌─────────────────────────────────────────────────────────────────┐
│              FSDP vs DDP                                        │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  DDP (Distributed Data Parallel)                                │
│  ───────────────────────────────                                │
│  • Full model replica on each GPU                               │
│  • Memory: O(model_size) per GPU                                │
│  • Communication: AllReduce gradients                           │
│                                                                 │
│  FSDP (Fully Sharded Data Parallel)                             │
│  ──────────────────────────────────                             │
│  • Model sharded across GPUs                                    │
│  • Memory: O(model_size / num_gpus) per GPU                     │
│  • Communication: AllGather params, ReduceScatter grads         │
│  • Best for: Models that don't fit in single GPU                │
│                                                                 │
│  WHEN TO USE WHAT                                               │
│  ────────────────                                               │
│  Model fits in 1 GPU → DDP                                      │
│  Model doesn't fit → FSDP or DeepSpeed ZeRO                     │
│  Very large model → FSDP + Tensor Parallelism                   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```python
# PyTorch FSDP
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# The auto-wrap policy must be bound to your model's transformer block class;
# TransformerBlock here is a placeholder for that class
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)

# Wrap model with FSDP: parameters, gradients, and optimizer state
# are sharded across ranks
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
)
```

Cost Control

Cloud GPU Pricing Strategy

┌─────────────────────────────────────────────────────────────────┐
│              COST OPTIMIZATION STRATEGIES                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. SPOT/PREEMPTIBLE INSTANCES                                  │
│     ─────────────────────────                                   │
│     • 60-90% cheaper than on-demand                             │
│     • Can be interrupted                                        │
│     • Use checkpointing for fault tolerance                     │
│                                                                 │
│  2. RIGHT-SIZING                                                │
│     ────────────                                                │
│     • Don't use A100 for small models                           │
│     • Profile memory usage first                                │
│     • Scale horizontally if cheaper                             │
│                                                                 │
│  3. EFFICIENT TRAINING                                          │
│     ───────────────────                                         │
│     • Mixed precision (2x speedup = 50% cost)                   │
│     • Gradient accumulation (use smaller instances)             │
│     • Early stopping (don't overtrain)                          │
│                                                                 │
│  4. SCHEDULING                                                  │
│     ──────────                                                  │
│     • Train during off-peak hours                               │
│     • Use reserved instances for predictable workloads          │
│     • Batch experiments together                                │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Checkpointing for Fault Tolerance

```python
# Robust checkpointing for spot instances
import os
import torch

def save_checkpoint(model, optimizer, epoch, loss, path):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    # Save to a temp file first, then rename (atomic on the same filesystem),
    # so an interruption mid-write never corrupts the previous checkpoint
    temp_path = path + '.tmp'
    torch.save(checkpoint, temp_path)
    os.rename(temp_path, path)

def load_checkpoint(model, optimizer, path):
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['loss']
    return 0, float('inf')
```
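These two helpers support a resume-anywhere loop: after a spot interruption you simply rerun the script and training picks up at the last saved epoch. Repeated here in a self-contained, CPU-runnable form, with a toy model and a placeholder per-epoch loss standing in for the real training step:

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 1)                                  # toy stand-ins for
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # the real model/optimizer
path = os.path.join(tempfile.gettempdir(), "scaling_demo_ckpt.pt")

def save_checkpoint(model, optimizer, epoch, loss, path):
    tmp = path + '.tmp'
    torch.save({'epoch': epoch, 'loss': loss,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()}, tmp)
    os.rename(tmp, path)  # atomic swap, as above

def load_checkpoint(model, optimizer, path):
    if os.path.exists(path):
        ckpt = torch.load(path, map_location='cpu')
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        return ckpt['epoch'], ckpt['loss']
    return 0, float('inf')

# Resume-aware loop: if the instance was reclaimed mid-run, rerunning the
# script continues from start_epoch instead of epoch 0
start_epoch, _ = load_checkpoint(model, optimizer, path)
for epoch in range(start_epoch, 3):
    loss = 1.0 / (epoch + 1)             # placeholder for train_one_epoch(...)
    save_checkpoint(model, optimizer, epoch + 1, loss, path)
```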

Cost Estimation Formula

┌─────────────────────────────────────────────────────────────────┐
│              TRAINING COST ESTIMATION                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Total Cost = GPU Hours × Price per Hour × Overhead Factor      │
│                                                                 │
│  GPU Hours = (Dataset Size × Epochs)                            │
│              ÷ (Throughput × 3600 sec/hour)                     │
│                                                                 │
│  Overhead Factor (typically 1.2-1.5):                           │
│  • Failed experiments                                           │
│  • Hyperparameter tuning                                        │
│  • Data preprocessing                                           │
│  • Debugging time                                               │
│                                                                 │
│  EXAMPLE                                                        │
│  ───────                                                        │
│  Dataset: 1M samples                                            │
│  Epochs: 10                                                     │
│  Batch size: 32                                                 │
│  Throughput: 100 samples/sec                                    │
│  GPU: A100 @ $3/hour                                            │
│                                                                 │
│  GPU Hours = (1M × 10) ÷ (100 × 3600) = 27.8 hours              │
│  Cost = 27.8 × $3 × 1.3 = ~$108                                 │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
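The formula maps directly to a small helper; this sketch reproduces the worked example above:

```python
def estimate_training_cost(dataset_size, epochs, throughput_sps,
                           price_per_hour, overhead=1.3):
    """GPU hours and dollar cost; overhead covers failed runs, tuning,
    preprocessing, and debugging (typically 1.2-1.5)."""
    gpu_hours = dataset_size * epochs / (throughput_sps * 3600)
    return gpu_hours, gpu_hours * price_per_hour * overhead

hours, cost = estimate_training_cost(1_000_000, 10, 100, 3.0)
print(f"{hours:.1f} GPU-hours, ~${cost:.0f}")  # 27.8 GPU-hours, ~$108
```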

Scaling Best Practices

Scaling Checklist

┌─────────────────────────────────────────────────────────────────┐
│              SCALING CHECKLIST                                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  BEFORE SCALING                                                 │
│  ──────────────                                                 │
│  □ Single-GPU training optimized                                │
│  □ Mixed precision enabled                                      │
│  □ Data loading not bottleneck                                  │
│  □ Model architecture finalized                                 │
│  □ Hyperparameters roughly tuned                                │
│                                                                 │
│  DURING SCALING                                                 │
│  ──────────────                                                 │
│  □ Learning rate scaled with batch size                         │
│  □ Warmup steps adjusted                                        │
│  □ Gradient accumulation configured                             │
│  □ Checkpointing enabled                                        │
│  □ Monitoring in place                                          │
│                                                                 │
│  AFTER SCALING                                                  │
│  ─────────────                                                  │
│  □ Verify convergence matches single-GPU                        │
│  □ Check for numerical differences                              │
│  □ Monitor GPU utilization                                      │
│  □ Track cost vs speedup                                        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Learning Rate Scaling

⚠️ Critical: Learning Rate Adjustment

When you increase the batch size, you need to adjust the learning rate. Rule of thumb: linear scaling with warmup.

```python
# Linear scaling rule
base_lr = 0.001
base_batch_size = 32
actual_batch_size = 256  # 8 GPUs × 32

scaled_lr = base_lr * (actual_batch_size / base_batch_size)
# scaled_lr = 0.008

# With warmup (recommended): ramp linearly up to the scaled LR
warmup_steps = 1000
def get_lr(step):
    if step < warmup_steps:
        return scaled_lr * (step / warmup_steps)
    return scaled_lr
```

Monitoring at Scale

Key Metrics to Track

| Metric | Target | Red Flag |
|---|---|---|
| GPU Utilization | >80% | <50% |
| Memory Usage | 80-95% | >98% or <50% |
| Throughput | Linear scaling | Sub-linear |
| Communication Time | <20% of step | >40% |
| Loss Convergence | Match single-GPU | Divergence |
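The "linear scaling" target in the table can be quantified as an efficiency ratio: measured multi-GPU throughput over the ideal N× single-GPU throughput. A sketch; the 0.75 threshold is a rule of thumb, not a hard limit:

```python
def scaling_efficiency(throughput_n_gpus, throughput_1_gpu, n_gpus):
    """1.0 means perfect linear scaling; values below ~0.75 usually point
    to a communication or data-loading bottleneck worth profiling."""
    return throughput_n_gpus / (throughput_1_gpu * n_gpus)

# e.g. 1 GPU sustains 100 samples/s but 8 GPUs only reach 600 samples/s
print(f"{scaling_efficiency(600, 100, 8):.0%}")  # 75%
```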

Distributed Training Debugging

```python
# Debug distributed training
import torch
import torch.distributed as dist

def debug_distributed(model):
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    print(f"[Rank {rank}/{world_size}] GPU: {torch.cuda.current_device()}")
    print(f"[Rank {rank}] Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    # Check gradient sync: after DDP's AllReduce, every rank should hold
    # identical gradients, so the reduced sums should match across ranks
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_sum = param.grad.sum()
            dist.all_reduce(grad_sum)
            if rank == 0:
                print(f"{name}: grad_sum = {grad_sum.item()}")
```

Operational Checklist

Pre-Scaling Gate

| Check | Required | Owner | Notes |
|---|---|---|---|
| Single-GPU training optimized | | ML Engineer | Mixed precision, data loading |
| Memory profiled | | ML Engineer | Know actual requirements |
| Baseline metrics established | | ML Engineer | Compare after scaling |
| Cost estimate approved | | ML Lead + Finance | Budget allocated |
| Checkpointing configured | | ML Engineer | Fault tolerance |

Scaling Configuration Gate

| Check | Required | Owner |
|---|---|---|
| Learning rate scaled correctly | | ML Engineer |
| Warmup steps adjusted | | ML Engineer |
| Gradient accumulation configured | | ML Engineer |
| DistributedSampler used | | ML Engineer |
| Monitoring dashboards ready | | ML Engineer |

Cost Control Gate

| Check | Required | Notes |
|---|---|---|
| Spot/preemptible instances evaluated | | 60-90% savings |
| Right GPU tier selected | | Don't over-provision |
| Auto-shutdown configured | | Idle instance termination |
| Cost alerts set up | | Budget threshold warnings |
| Training duration estimated | | Expected vs actual tracking |

📎 Cross-References