
🎓 Training Foundations

Level: Foundation · Goal: A deep understanding of the fundamentals of neural network training, from optimization to regularization

Tại sao Training Foundations quan trọng?

💡 Professor Tom

Deep learning is not magic - it is applied optimization. A solid grasp of gradient descent, loss landscapes, and regularization will help you debug training issues 10x faster than just copy-pasting code from tutorials.

┌─────────────────────────────────────────────────────────────────┐
│              DEEP LEARNING TRAINING PIPELINE                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────┐    ┌─────────┐    ┌─────────┐    ┌─────────┐      │
│  │  Data   │ →  │ Forward │ →  │  Loss   │ →  │Backward │      │
│  │  Batch  │    │  Pass   │    │ Compute │    │  Pass   │      │
│  └─────────┘    └─────────┘    └─────────┘    └─────────┘      │
│       │                                            │            │
│       │         ┌─────────────────────────┐        │            │
│       └─────────│     Weight Update       │←───────┘            │
│                 │   (Optimizer Step)      │                     │
│                 └─────────────────────────┘                     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Optimization Fundamentals

Gradient Descent Variants

| Variant | Batch Size | Pros | Cons |
|---|---|---|---|
| Batch GD | Full dataset | Stable gradients | Slow, memory intensive |
| Stochastic GD | 1 sample | Fast updates | Noisy gradients |
| Mini-batch GD | 32-512 | Balances speed and stability | Batch size becomes a hyperparameter to tune |
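The mini-batch variant is what the training pipeline diagram above executes in practice. A minimal sketch of one epoch, using a toy linear model and random data as placeholders:

```python
import torch

# Placeholder model, loss, and data for illustration only.
model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

X = torch.randn(256, 10)  # toy dataset: 256 samples, 10 features
y = torch.randn(256, 1)
dataset = torch.utils.data.TensorDataset(X, y)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

for xb, yb in loader:              # one epoch of mini-batch updates
    optimizer.zero_grad()          # clear gradients from the previous step
    loss = loss_fn(model(xb), yb)  # forward pass + loss compute
    loss.backward()                # backward pass (autograd)
    optimizer.step()               # weight update (optimizer step)
```

Each iteration is one pass through the Data → Forward → Loss → Backward → Update loop from the diagram.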

Modern Optimizers

┌─────────────────────────────────────────────────────────────────┐
│                    OPTIMIZER EVOLUTION                          │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  SGD → SGD+Momentum → RMSprop → Adam → AdamW                   │
│   │         │            │        │       │                     │
│   │         │            │        │       └─ Weight decay fix   │
│   │         │            │        └─ Momentum + Adaptive LR     │
│   │         │            └─ Adaptive learning rates             │
│   │         └─ Accelerated convergence                          │
│   └─ Basic gradient descent                                     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Optimizer Selection Guide

| Optimizer | Best For | Learning Rate | Notes |
|---|---|---|---|
| SGD+Momentum | CNNs, well-tuned models | 0.01-0.1 | Often best final performance |
| Adam | Transformers, quick prototyping | 1e-4 to 3e-4 | Good default choice |
| AdamW | Transformers with regularization | 1e-4 to 3e-4 | Proper weight decay |
| LAMB | Large-batch training | Scaled | For batch sizes > 8K |
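A sketch of setting up AdamW with the learning rate range from the table. Excluding biases and normalization parameters from weight decay is a common convention, not something the table prescribes:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))

# Matrices get weight decay; 1-D params (biases, norm scales) do not.
decay = [p for p in model.parameters() if p.ndim >= 2]
no_decay = [p for p in model.parameters() if p.ndim < 2]

optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.01},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=3e-4,             # within the 1e-4 to 3e-4 range from the table
    betas=(0.9, 0.999),  # defaults are usually fine (see Priority 4 below)
)
```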

Learning Rate Schedules

┌─────────────────────────────────────────────────────────────────┐
│              LEARNING RATE SCHEDULE PATTERNS                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Step Decay          Cosine Annealing      Warmup + Decay       │
│  ──────────          ────────────────      ──────────────       │
│  LR│ ▄▄▄▄            LR│    ╭──╮           LR│   ╭────╮         │
│    │     ▄▄▄▄          │   ╱    ╲            │  ╱      ╲        │
│    │         ▄▄▄▄      │  ╱      ╲           │ ╱        ╲       │
│    │             ▄▄    │ ╱        ╲          │╱          ╲      │
│    └──────────────→    └───────────→         └────────────→     │
│         Epochs              Epochs               Epochs         │
│                                                                 │
│  Use: CNNs, stable     Use: Transformers     Use: Large models  │
│  training              fine-tuning           from scratch       │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
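The "Warmup + Decay" pattern on the right can be sketched with `LambdaLR`: linear warmup for the first steps, then cosine decay toward zero. The model and step counts here are placeholders:

```python
import math
import torch

model = torch.nn.Linear(8, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

warmup_steps, total_steps = 10, 100

def lr_lambda(step):
    # Linear warmup, then cosine decay to ~0 (the "Warmup + Decay" shape).
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# In the training loop: optimizer.step(), then scheduler.step(), once per batch.
```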

Loss Functions

Classification Losses

| Loss | Formula | Use Case |
|---|---|---|
| Cross-Entropy | -Σ y·log(ŷ) | Multi-class classification |
| Binary CE | -y·log(ŷ) - (1-y)·log(1-ŷ) | Binary classification |
| Focal Loss | -α(1-ŷ)^γ·log(ŷ) | Imbalanced datasets |
| Label Smoothing | CE with soft targets | Regularization, calibration |
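A minimal binary focal loss following the table's formula, applied symmetrically to the negative class. The α-weighting convention varies across implementations; this is one common choice:

```python
import torch

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # -α(1-ŷ)^γ·log(ŷ), where ŷ is the probability of the true class.
    p = torch.sigmoid(logits)
    pt = p * targets + (1 - p) * (1 - targets)               # prob. of true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # class weighting
    return (-alpha_t * (1 - pt) ** gamma * torch.log(pt.clamp_min(1e-8))).mean()
```

The (1-ŷ)^γ factor down-weights easy, well-classified examples, which is why it helps on imbalanced datasets.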

Regression Losses

| Loss | Formula | Characteristics |
|---|---|---|
| MSE (L2) | (y - ŷ)² | Penalizes large errors heavily |
| MAE (L1) | \|y - ŷ\| | Robust to outliers |
| Huber | L2 if small, L1 if large | Best of both worlds |
| Log-Cosh | log(cosh(y - ŷ)) | Smooth approximation of Huber |
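The table's characteristics are easy to see on a toy example with one outlier residual (the values here are made up for illustration):

```python
import torch
import torch.nn.functional as F

y = torch.zeros(5)
pred = torch.tensor([0.1, -0.2, 0.3, 0.1, 10.0])  # last residual is an outlier

mse = F.mse_loss(pred, y)                 # dominated by the outlier (error squared)
mae = F.l1_loss(pred, y)                  # linear in the error: far less sensitive
huber = F.huber_loss(pred, y, delta=1.0)  # quadratic near 0, linear beyond delta
```

MSE blows up on the single bad point while MAE and Huber stay on the scale of the typical error.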

Specialized Losses

┌─────────────────────────────────────────────────────────────────┐
│                  SPECIALIZED LOSS FUNCTIONS                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  CONTRASTIVE LEARNING                                           │
│  ────────────────────                                           │
│  • InfoNCE: Self-supervised representation learning             │
│  • Triplet Loss: Metric learning (anchor, positive, negative)   │
│  • NT-Xent: SimCLR-style contrastive loss                       │
│                                                                 │
│  SEQUENCE MODELING                                              │
│  ─────────────────                                              │
│  • CTC Loss: Speech recognition, OCR                            │
│  • Sequence CE: Language modeling                               │
│                                                                 │
│  GENERATIVE MODELS                                              │
│  ─────────────────                                              │
│  • Reconstruction Loss: Autoencoders                            │
│  • KL Divergence: VAEs                                          │
│  • Adversarial Loss: GANs                                       │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Weight Initialization

Why Initialization Matters

🚨 Bad Initialization

Bad initialization can lead to:

  • Vanishing gradients: weights too small → gradients → 0
  • Exploding gradients: weights too large → gradients → ∞
  • Dead neurons: ReLU neurons stuck at 0
  • Slow convergence: training takes forever

Initialization Strategies

| Method | Formula | Best For |
|---|---|---|
| Xavier/Glorot | U(-√(6/(n_in+n_out)), √(6/(n_in+n_out))) | Tanh, Sigmoid activations |
| He/Kaiming | N(0, √(2/n_in)) | ReLU activations |
| Orthogonal | QR decomposition | RNNs, deep networks |
| LSUV | Layer-sequential unit-variance | Very deep networks |

Initialization by Layer Type

┌─────────────────────────────────────────────────────────────────┐
│              INITIALIZATION BY LAYER TYPE                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Layer Type          Recommended Init       Notes               │
│  ──────────          ────────────────       ─────               │
│                                                                 │
│  Linear + ReLU       He (Kaiming)           fan_in mode         │
│  Linear + Tanh       Xavier (Glorot)        Uniform or Normal   │
│  Conv2d + ReLU       He (Kaiming)           fan_out for backprop│
│  LSTM/GRU            Orthogonal             For recurrent weights│
│  Embedding           Normal(0, 0.02)        Or uniform small    │
│  LayerNorm           weight=1, bias=0       Standard practice   │
│  BatchNorm           weight=1, bias=0       Standard practice   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Regularization Techniques

Overview of Regularization

┌─────────────────────────────────────────────────────────────────┐
│                 REGULARIZATION TAXONOMY                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐  │
│  │   EXPLICIT      │  │   IMPLICIT      │  │   DATA-BASED    │  │
│  ├─────────────────┤  ├─────────────────┤  ├─────────────────┤  │
│  │ • L1/L2 penalty │  │ • Early stopping│  │ • Data augment  │  │
│  │ • Dropout       │  │ • Batch norm    │  │ • Mixup/CutMix  │  │
│  │ • Weight decay  │  │ • Noise inject  │  │ • Label smooth  │  │
│  │ • Max-norm      │  │ • Gradient clip │  │ • Curriculum    │  │
│  └─────────────────┘  └─────────────────┘  └─────────────────┘  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Dropout

| Variant | Description | Use Case |
|---|---|---|
| Standard Dropout | Random zero-out during training | Fully connected layers |
| Spatial Dropout | Drop entire feature maps | CNNs |
| DropConnect | Drop weights instead of activations | Alternative to dropout |
| DropBlock | Drop contiguous regions | CNNs, better than spatial |
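Standard dropout behaves differently in train and eval mode, which is a frequent source of bugs. A quick demonstration:

```python
import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)  # roughly half the activations zeroed; survivors scaled by 1/(1-p) = 2

drop.eval()
y_eval = drop(x)   # identity at inference: nothing dropped, no scaling
```

Forgetting `model.eval()` at inference time leaves dropout active and silently degrades predictions.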

Weight Decay vs L2 Regularization

⚠️ Subtle Difference

Weight decay and L2 regularization are NOT the same thing with adaptive optimizers (Adam):

  • L2 Reg: Adds λ·w² to loss → gradient includes λ·w
  • Weight Decay: Directly decays weights → w = w - lr·λ·w

AdamW implements proper decoupled weight decay; Adam with L2 regularization behaves differently!
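The decoupled update can be sketched by hand on a single toy weight. With L2-in-the-loss, λ·w is added to the gradient and then rescaled by Adam's adaptive moments; the decoupled version below touches the weight directly:

```python
import torch

lr, wd = 0.1, 0.01
w = torch.tensor([1.0], requires_grad=True)

loss = (w ** 2).sum()  # toy loss with gradient 2w
loss.backward()

with torch.no_grad():
    w -= lr * w.grad   # plain gradient step: 1.0 -> 0.8
    w -= lr * wd * w   # decoupled weight decay (AdamW-style): w <- w - lr*λ*w
```

In AdamW the decay term never passes through the adaptive scaling, so every weight decays at the same relative rate.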

Batch Normalization

┌─────────────────────────────────────────────────────────────────┐
│                  BATCH NORMALIZATION                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Training:                                                      │
│  ─────────                                                      │
│  1. Compute batch mean: μ_B = (1/m) Σ x_i                       │
│  2. Compute batch var:  σ²_B = (1/m) Σ (x_i - μ_B)²             │
│  3. Normalize: x̂_i = (x_i - μ_B) / √(σ²_B + ε)                 │
│  4. Scale & shift: y_i = γ·x̂_i + β                             │
│                                                                 │
│  Inference:                                                     │
│  ──────────                                                     │
│  Use running mean/var computed during training                  │
│                                                                 │
│  Benefits:                                                      │
│  • Faster training (higher learning rates)                      │
│  • Reduces internal covariate shift                             │
│  • Acts as regularization                                       │
│                                                                 │
│  Pitfalls:                                                      │
│  • Small batch sizes → noisy statistics                         │
│  • Different behavior train vs inference                        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
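The four training-time steps in the box translate directly into code. A minimal sketch for a (batch, features) tensor, without the running statistics used at inference:

```python
import torch

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean(dim=0)                        # 1. batch mean
    var = x.var(dim=0, unbiased=False)        # 2. batch variance
    x_hat = (x - mu) / torch.sqrt(var + eps)  # 3. normalize
    return gamma * x_hat + beta               # 4. scale & shift

x = torch.randn(32, 8) * 3 + 5  # toy activations, shifted and scaled
y = batch_norm_forward(x, gamma=torch.ones(8), beta=torch.zeros(8))
# y now has per-feature mean ~0 and variance ~1
```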

Layer Normalization vs Batch Normalization

| Aspect | Batch Norm | Layer Norm |
|---|---|---|
| Normalizes over | Batch dimension | Feature dimension |
| Batch size dependency | Yes | No |
| Best for | CNNs | Transformers, RNNs |
| Inference behavior | Uses running stats | Same as training |

Gradient Flow

Vanishing/Exploding Gradients

┌─────────────────────────────────────────────────────────────────┐
│              GRADIENT FLOW PROBLEMS                             │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  VANISHING GRADIENTS                                            │
│  ───────────────────                                            │
│  Symptoms:                                                      │
│  • Early layers don't learn                                     │
│  • Loss plateaus early                                          │
│  • Gradient norms → 0                                           │
│                                                                 │
│  Causes:                                                        │
│  • Sigmoid/Tanh saturation                                      │
│  • Too many layers                                              │
│  • Poor initialization                                          │
│                                                                 │
│  Solutions:                                                     │
│  • ReLU activations                                             │
│  • Skip connections (ResNet)                                    │
│  • Proper initialization                                        │
│  • Batch/Layer normalization                                    │
│                                                                 │
│  EXPLODING GRADIENTS                                            │
│  ───────────────────                                            │
│  Symptoms:                                                      │
│  • NaN losses                                                   │
│  • Weights → ∞                                                  │
│  • Unstable training                                            │
│                                                                 │
│  Solutions:                                                     │
│  • Gradient clipping                                            │
│  • Lower learning rate                                          │
│  • Proper initialization                                        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
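The symptoms above are easy to check directly: after a backward pass, inspect per-layer gradient norms. A small sketch with a toy Tanh network:

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(10, 10), torch.nn.Tanh(),
    torch.nn.Linear(10, 10), torch.nn.Tanh(),
    torch.nn.Linear(10, 1),
)
loss = model(torch.randn(4, 10)).sum()
loss.backward()

# Vanishing gradients show up as norms -> 0 in the early layers;
# exploding gradients as huge norms or NaNs.
for name, p in model.named_parameters():
    print(f"{name}: grad norm = {p.grad.norm().item():.3e}")
```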

Gradient Clipping

```python
# Gradient clipping by norm (recommended)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Gradient clipping by value
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
```
| Method | When to Use | Typical Values |
|---|---|---|
| Clip by Norm | RNNs, Transformers | 1.0 - 5.0 |
| Clip by Value | Specific gradient control | 0.5 - 1.0 |
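Placement matters: clipping must happen after `backward()` (so gradients exist) and before `step()` (so the optimizer sees the clipped values). A minimal sketch with a placeholder model:

```python
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
x, y = torch.randn(32, 10), torch.randn(32, 1)

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
# Clip between backward() and step().
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

`clip_grad_norm_` returns the pre-clip total norm, which is worth logging to spot exploding gradients early.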

Training Best Practices

Hyperparameter Priority

┌─────────────────────────────────────────────────────────────────┐
│           HYPERPARAMETER TUNING PRIORITY                        │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Priority 1 (Tune First):                                       │
│  • Learning rate                                                │
│  • Batch size                                                   │
│  • Number of epochs                                             │
│                                                                 │
│  Priority 2 (Tune Second):                                      │
│  • Model architecture (depth, width)                            │
│  • Optimizer choice                                             │
│  • Learning rate schedule                                       │
│                                                                 │
│  Priority 3 (Fine-tune):                                        │
│  • Regularization strength                                      │
│  • Dropout rate                                                 │
│  • Weight decay                                                 │
│                                                                 │
│  Priority 4 (Usually defaults work):                            │
│  • Optimizer betas (Adam)                                       │
│  • Epsilon values                                               │
│  • Initialization details                                       │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Training Checklist

✅ Before Training

  1. Data: Verify data loading, augmentation, normalization
  2. Model: Check parameter count, forward pass works
  3. Loss: Verify loss computation on dummy data
  4. Overfit: Can model overfit small batch? (sanity check)
  5. Baseline: Compare against simple baseline
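Step 4 of the checklist, overfitting one small batch, can be sketched in a few lines. The architecture and step count here are arbitrary placeholders:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.ReLU(),
                            torch.nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
xb, yb = torch.randn(8, 10), torch.randn(8, 1)  # one small fixed batch

for _ in range(500):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(xb), yb)
    loss.backward()
    optimizer.step()
# A healthy setup should drive the loss close to zero on these 8 points.
```

If the loss refuses to approach zero on 8 samples, there is a bug in the data, model, or loss; fix that before touching any hyperparameters.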

📎 Cross-References