
🔍 Debugging Training

Level: Advanced
Solves: Identify and fix common training problems, from data issues to gradient problems

The Debugging Mindset

💡 Professor Tom

Deep learning debugging is detective work. Your model isn't "broken" - it's doing exactly what you told it to do. The problem is that what you told it is wrong. Systematic debugging helps you find where that "wrong" lives.

┌─────────────────────────────────────────────────────────────────┐
│              DEBUGGING PRIORITY ORDER                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. DATA (60% of issues)                                        │
│     • Data loading correct?                                     │
│     • Labels correct?                                           │
│     • Preprocessing correct?                                    │
│                                                                 │
│  2. LOSS & METRICS (20% of issues)                              │
│     • Loss function appropriate?                                │
│     • Metrics computed correctly?                               │
│     • Class imbalance handled?                                  │
│                                                                 │
│  3. MODEL (15% of issues)                                       │
│     • Architecture appropriate?                                 │
│     • Initialization correct?                                   │
│     • Forward pass correct?                                     │
│                                                                 │
│  4. OPTIMIZATION (5% of issues)                                 │
│     • Learning rate appropriate?                                │
│     • Optimizer choice?                                         │
│     • Gradient flow?                                            │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Data Issues

Common Data Problems

┌─────────────────────────────────────────────────────────────────┐
│                    DATA DEBUGGING CHECKLIST                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  LOADING ISSUES                                                 │
│  ──────────────                                                 │
│  □ Data actually loading? (print shapes, samples)               │
│  □ Correct data type? (float32, not float64)                    │
│  □ Correct device? (GPU vs CPU)                                 │
│  □ Shuffling working? (check batch diversity)                   │
│                                                                 │
│  LABEL ISSUES                                                   │
│  ────────────                                                   │
│  □ Labels match inputs? (alignment check)                       │
│  □ Label encoding correct? (0-indexed, one-hot)                 │
│  □ No label leakage? (future info in features)                  │
│  □ Class distribution? (imbalance check)                        │
│                                                                 │
│  PREPROCESSING ISSUES                                           │
│  ────────────────────                                           │
│  □ Normalization correct? (mean/std from train only)            │
│  □ Augmentation reasonable? (visualize augmented samples)       │
│  □ Same preprocessing train/val/test?                           │
│  □ No data corruption? (NaN, Inf values)                        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Data Sanity Checks

python
import torch

# Essential data checks before training
def sanity_check_data(dataloader):
    batch = next(iter(dataloader))
    x, y = batch
    
    # Shape check
    print(f"Input shape: {x.shape}")
    print(f"Label shape: {y.shape}")
    
    # Value range check
    print(f"Input range: [{x.min():.3f}, {x.max():.3f}]")
    print(f"Input mean: {x.mean():.3f}, std: {x.std():.3f}")
    
    # NaN/Inf check
    assert not torch.isnan(x).any(), "NaN in inputs!"
    assert not torch.isinf(x).any(), "Inf in inputs!"
    
    # Label distribution
    if y.dim() == 1:  # Classification
        unique, counts = torch.unique(y, return_counts=True)
        print(f"Label distribution: {dict(zip(unique.tolist(), counts.tolist()))}")
    
    # Visualize samples (for images)
    # plt.imshow(x[0].permute(1,2,0).cpu())

Label Alignment Bug

🚨 Silent Killer: Misaligned Labels

One of the hardest bugs to detect is labels misaligned with their inputs. The model still trains and the loss still decreases (because it learns the noise), but performance on the test set is terrible.

python
import numpy as np

# ❌ WRONG: Shuffling inputs but not labels
np.random.shuffle(X)  # labels Y are NOT shuffled along with X!

# ✅ CORRECT: Shuffle together
indices = np.random.permutation(len(X))
X = X[indices]
Y = Y[indices]

# ✅ BETTER: Use DataLoader with shuffle=True
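A minimal sketch of that DataLoader approach, assuming X and Y are NumPy arrays:

python
import torch
from torch.utils.data import DataLoader, TensorDataset

# ✅ BETTER: DataLoader shuffles indices, so inputs and labels stay paired
dataset = TensorDataset(torch.as_tensor(X, dtype=torch.float32),
                        torch.as_tensor(Y, dtype=torch.long))
loader = DataLoader(dataset, batch_size=32, shuffle=True)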

Loss & Metrics Issues

Loss Function Debugging

┌─────────────────────────────────────────────────────────────────┐
│                 LOSS DEBUGGING GUIDE                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  SYMPTOM                    LIKELY CAUSE                        │
│  ───────                    ────────────                        │
│                                                                 │
│  Loss = NaN                 • Exploding gradients               │
│                             • Log of zero/negative              │
│                             • Division by zero                  │
│                             • Learning rate too high            │
│                                                                 │
│  Loss = Inf                 • Numerical overflow                │
│                             • Extreme predictions               │
│                                                                 │
│  Loss stuck high            • Wrong loss function               │
│                             • Learning rate too low             │
│                             • Model too simple                  │
│                             • Data issue                        │
│                                                                 │
│  Loss oscillates wildly     • Learning rate too high            │
│                             • Batch size too small              │
│                             • Data shuffling issue              │
│                                                                 │
│  Train loss ↓, val loss ↑   • Overfitting                       │
│                             • Data leakage in train             │
│                             • Train/val distribution mismatch   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
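For the NaN case, PyTorch's anomaly detection can locate the offending operation. A minimal sketch, assuming model, x, and y from your training loop (enable only while debugging - it slows training noticeably):

python
import torch
import torch.nn.functional as F

# backward() will now raise at the operation that produced NaN,
# with a traceback pointing to the corresponding forward op
torch.autograd.set_detect_anomaly(True)

loss = F.cross_entropy(model(x), y)  # assumes model, x, y exist
loss.backward()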

Cross-Entropy Pitfalls

python
import torch.nn.functional as F

# ❌ WRONG: Softmax + CrossEntropyLoss (double softmax!)
output = F.softmax(logits, dim=1)
loss = F.cross_entropy(output, labels)  # CE already applies softmax!

# ✅ CORRECT: Raw logits + CrossEntropyLoss
loss = F.cross_entropy(logits, labels)

# ❌ WRONG: BCELoss with logits
output = model(x)  # Raw logits
loss = F.binary_cross_entropy(output, labels)  # Expects probabilities!

# ✅ CORRECT: BCEWithLogitsLoss
loss = F.binary_cross_entropy_with_logits(output, labels)

Model Issues

The Overfit-One-Batch Test

✅ First Sanity Check

Before training on the full dataset, verify that the model can overfit one small batch. If it can't, there is a bug in the model or the training loop.

python
import torch.nn.functional as F

def overfit_one_batch(model, dataloader, optimizer, epochs=100):
    """Model should achieve ~0 loss on one batch"""
    model.train()
    batch = next(iter(dataloader))
    x, y = batch
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(x)
        loss = F.cross_entropy(output, y)
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            acc = (output.argmax(1) == y).float().mean()
            print(f"Epoch {epoch}: Loss={loss:.4f}, Acc={acc:.4f}")
    
    # Should see loss → 0, acc → 1.0
    # If not, there's a bug!

Common Model Bugs

┌─────────────────────────────────────────────────────────────────┐
│                 COMMON MODEL BUGS                               │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  BUG: Forgot model.train() / model.eval()                       │
│  ─────────────────────────────────────────                      │
│  Impact: BatchNorm/Dropout behave wrong                         │
│  Fix: Always set mode before forward pass                       │
│                                                                 │
│  BUG: Wrong input dimensions                                    │
│  ────────────────────────────                                   │
│  Impact: Silent broadcasting, wrong results                     │
│  Fix: Assert shapes at each layer                               │
│                                                                 │
│  BUG: Activation after final layer                              │
│  ─────────────────────────────────                              │
│  Impact: Limits output range incorrectly                        │
│  Fix: No activation before loss (CE expects logits)             │
│                                                                 │
│  BUG: Shared weights unintentionally                            │
│  ────────────────────────────────                               │
│  Impact: Layers update together                                 │
│  Fix: Create new layer instances, don't reuse                   │
│                                                                 │
│  BUG: Detached tensors in computation                           │
│  ─────────────────────────────────                              │
│  Impact: Gradients don't flow                                   │
│  Fix: Don't use .detach() or .data in training                  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
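A minimal sketch of the shape-assertion fix, with illustrative layer sizes:

python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Assert shapes at each layer to catch silent broadcasting early
        assert x.shape[1:] == (784,), f"expected (N, 784), got {tuple(x.shape)}"
        h = torch.relu(self.fc1(x))
        assert h.shape[1:] == (128,), f"expected (N, 128), got {tuple(h.shape)}"
        return self.fc2(h)  # raw logits - no final activation (CE expects logits)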

Gradient Issues

Gradient Debugging

python
import torch

def check_gradients(model):
    """Check gradient health after backward pass"""
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad = param.grad
            
            # Check for NaN/Inf
            if torch.isnan(grad).any():
                print(f"⚠️ NaN gradient in {name}")
            if torch.isinf(grad).any():
                print(f"⚠️ Inf gradient in {name}")
            
            # Check gradient magnitude
            grad_norm = grad.norm()
            if grad_norm < 1e-7:
                print(f"⚠️ Vanishing gradient in {name}: {grad_norm:.2e}")
            if grad_norm > 1e3:
                print(f"⚠️ Exploding gradient in {name}: {grad_norm:.2e}")
        else:
            print(f"⚠️ No gradient for {name}")

Vanishing Gradients

┌─────────────────────────────────────────────────────────────────┐
│              VANISHING GRADIENTS                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  SYMPTOMS                                                       │
│  ────────                                                       │
│  • Early layers don't update                                    │
│  • Loss decreases very slowly                                   │
│  • Gradient norms near zero                                     │
│                                                                 │
│  CAUSES                                                         │
│  ──────                                                         │
│  • Sigmoid/Tanh saturation                                      │
│  • Too many layers without skip connections                     │
│  • Poor weight initialization                                   │
│  • Very small learning rate                                     │
│                                                                 │
│  SOLUTIONS                                                      │
│  ─────────                                                      │
│  • Use ReLU/GELU activations                                    │
│  • Add skip connections (ResNet-style)                          │
│  • Use proper initialization (He/Xavier)                        │
│  • Add batch/layer normalization                                │
│  • Reduce network depth                                         │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
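A minimal sketch of the He-initialization fix from the list above, assuming an existing model:

python
import torch.nn as nn

def init_weights(module):
    # He (Kaiming) initialization suits ReLU-family activations
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model.apply(init_weights)  # assumes model is an existing nn.Module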

Exploding Gradients

┌─────────────────────────────────────────────────────────────────┐
│              EXPLODING GRADIENTS                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  SYMPTOMS                                                       │
│  ────────                                                       │
│  • Loss becomes NaN                                             │
│  • Weights become very large                                    │
│  • Training becomes unstable                                    │
│                                                                 │
│  CAUSES                                                         │
│  ──────                                                         │
│  • Learning rate too high                                       │
│  • Poor initialization                                          │
│  • Deep RNNs without gradient clipping                          │
│  • Numerical instability in loss                                │
│                                                                 │
│  SOLUTIONS                                                      │
│  ─────────                                                      │
│  • Gradient clipping (clip_grad_norm_)                          │
│  • Lower learning rate                                          │
│  • Use gradient-friendly architectures (LSTM, Transformer)      │
│  • Add normalization layers                                     │
│  • Check for numerical issues in loss                           │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
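Gradient clipping is a one-line addition between backward() and step(). A minimal sketch, assuming model, optimizer, x, y from your training loop:

python
import torch
import torch.nn.functional as F

optimizer.zero_grad()
loss = F.cross_entropy(model(x), y)
loss.backward()
# Clip between backward() and step(): rescales all gradients
# if their total norm exceeds max_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()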

Silent Bugs

The Most Dangerous Bugs

🚨 Silent Bugs

Silent bugs are bugs that don't crash your program but quietly degrade model performance. They are behind most "my model isn't learning" issues.

┌─────────────────────────────────────────────────────────────────┐
│                 SILENT BUG CATALOG                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. BROADCASTING BUGS                                           │
│     ─────────────────                                           │
│     tensor_a (100, 1) + tensor_b (1, 100) = (100, 100)          │
│     Often unintended, causes wrong computations                 │
│                                                                 │
│  2. IN-PLACE OPERATIONS                                         │
│     ────────────────────                                        │
│     x += 1  # Can break autograd                                │
│     Use: x = x + 1                                              │
│                                                                 │
│  3. WRONG REDUCTION                                             │
│     ───────────────                                             │
│     loss.mean() vs loss.sum() - affects gradient scale          │
│                                                                 │
│  4. FORGOTTEN ZERO_GRAD                                         │
│     ─────────────────────                                       │
│     Gradients accumulate across batches                         │
│     Always: optimizer.zero_grad() before backward()             │
│                                                                 │
│  5. EVAL MODE FORGOTTEN                                         │
│     ────────────────────                                        │
│     BatchNorm/Dropout behave differently in eval                │
│     Always: model.eval() for validation                         │
│                                                                 │
│  6. DATA LEAKAGE                                                │
│     ────────────                                                │
│     Preprocessing with test data statistics                     │
│     Fit scalers on train only!                                  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
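A minimal loop skeleton illustrating fixes #4 and #5, assuming model, optimizer, train_loader, and x_val already exist:

python
import torch
import torch.nn.functional as F

for x, y in train_loader:
    model.train()
    optimizer.zero_grad()          # clear accumulated gradients (#4)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()

model.eval()                       # BatchNorm/Dropout in inference mode (#5)
with torch.no_grad():
    val_logits = model(x_val)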

Debugging Workflow

Systematic Debugging Process

┌─────────────────────────────────────────────────────────────────┐
│              DEBUGGING WORKFLOW                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Step 1: REPRODUCE                                              │
│  ─────────────────                                              │
│  • Set random seeds                                             │
│  • Save exact configuration                                     │
│  • Ensure issue is consistent                                   │
│                                                                 │
│  Step 2: SIMPLIFY                                               │
│  ────────────────                                               │
│  • Reduce to minimal example                                    │
│  • Use tiny dataset (10-100 samples)                            │
│  • Use simple model first                                       │
│                                                                 │
│  Step 3: VERIFY COMPONENTS                                      │
│  ─────────────────────────                                      │
│  • Data loading ✓                                               │
│  • Forward pass ✓                                               │
│  • Loss computation ✓                                           │
│  • Backward pass ✓                                              │
│  • Weight update ✓                                              │
│                                                                 │
│  Step 4: COMPARE                                                │
│  ────────────────                                               │
│  • Against known working implementation                         │
│  • Against simple baseline                                      │
│  • Against expected behavior                                    │
│                                                                 │
│  Step 5: ISOLATE                                                │
│  ────────────────                                               │
│  • Binary search through code                                   │
│  • Add assertions at each step                                  │
│  • Print intermediate values                                    │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
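For Step 1, a minimal seed-setting sketch covering the common sources of randomness:

python
import random
import numpy as np
import torch

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Make cuDNN deterministic (may slow training)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False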

Debugging Tools

| Tool                          | Purpose                     | When to Use         |
|-------------------------------|-----------------------------|---------------------|
| TensorBoard                   | Loss/metric visualization   | Always              |
| torch.autograd.detect_anomaly | Find NaN source             | NaN debugging       |
| torch.autograd.gradcheck      | Verify gradient correctness | Custom layers       |
| Weights & Biases              | Experiment tracking         | Production training |
| PyTorch Profiler              | Performance bottlenecks     | Slow training       |
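A minimal gradcheck sketch for a hypothetical custom op; note that gradcheck requires double-precision inputs with requires_grad=True:

python
import torch

def my_op(x):                # hypothetical custom op to verify
    return (x ** 2).sum()

# Compares analytical gradients against finite differences;
# returns True on success, raises on mismatch
x = torch.randn(5, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(my_op, (x,))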

Operational Checklist

Pre-Training Sanity Checks

| Check                       | Required | How to Verify                |
|-----------------------------|----------|------------------------------|
| Data loading verified       | ✓        | sanity_check_data() passes   |
| Labels correctly aligned    | ✓        | Spot-check input-label pairs |
| No NaN/Inf in inputs        | ✓        | Assertion in data pipeline   |
| Normalization stats correct | ✓        | Mean ≈ 0, Std ≈ 1            |
| Overfit-one-batch passes    | ✓        | Loss → 0, Acc → 100%         |
| Gradient flow verified      | ✓        | All params have gradients    |
| Random seeds set            | ✓        | Reproducibility confirmed    |

Training Stability Checks

| Check                     | Required | Trigger         |
|---------------------------|----------|-----------------|
| Loss is finite            | ✓        | Every batch     |
| Gradient norms healthy    | ✓        | Every N steps   |
| No NaN in weights         | ✓        | Every epoch     |
| Validation improving      | ✓        | Every epoch     |
| Early stopping configured | ✓        | Before training |

Pre-Deployment Debug Gate

| Check                        | Required | Owner              |
|------------------------------|----------|--------------------|
| Training logs reviewed       | ✓        | ML Engineer        |
| Gradient issues investigated | ✓        | ML Engineer        |
| Overfitting diagnosed        | ✓        | ML Lead            |
| Silent bugs ruled out        | ✓        | Senior ML Engineer |

📎 Cross-References