
🎯 Model Selection

Level: Core | Solves: Choosing a model that fits the problem, avoiding both over-engineering and under-engineering

Philosophy: Baseline First

💡 Professor Tom

"Start simple, add complexity only when needed." A logistic regression can beat a neural network when the data is too small or the problem too simple. A more complex model is not always a better one.

┌─────────────────────────────────────────────────────────────────┐
│              MODEL COMPLEXITY LADDER                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Complexity ▲                                                   │
│             │                                                   │
│             │  ┌─────────────────────────────────────────────┐  │
│             │  │ Deep Learning (Transformers, CNNs)          │  │
│             │  │ • Millions of parameters                    │  │
│             │  │ • Requires massive data                     │  │
│             │  └─────────────────────────────────────────────┘  │
│             │                                                   │
│             │  ┌─────────────────────────────────────────────┐  │
│             │  │ Ensemble Methods (XGBoost, LightGBM)        │  │
│             │  │ • Hundreds of trees                         │  │
│             │  │ • Good for tabular data                     │  │
│             │  └─────────────────────────────────────────────┘  │
│             │                                                   │
│             │  ┌─────────────────────────────────────────────┐  │
│             │  │ Single Models (Random Forest, SVM)          │  │
│             │  │ • Moderate complexity                       │  │
│             │  │ • Good interpretability                     │  │
│             │  └─────────────────────────────────────────────┘  │
│             │                                                   │
│             │  ┌─────────────────────────────────────────────┐  │
│  START HERE │  │ Simple Baselines (Logistic, Linear)         │  │
│      ▼      │  │ • Few parameters                            │  │
│             │  │ • Fast, interpretable                       │  │
│             │  └─────────────────────────────────────────────┘  │
│             │                                                   │
│             └──────────────────────────────────────────────────►│
│                              Data Size / Problem Complexity     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Baseline Models

Why Baselines Matter

| Reason | Explanation |
| --- | --- |
| Sanity check | Ensures the problem is learnable |
| Benchmark | Measures improvement of complex models |
| Production fallback | Simple model as backup |
| Debugging | Easier to understand failures |
| Speed | Fast iteration in early stages |

Essential Baselines

┌─────────────────────────────────────────────────────────────────┐
│              BASELINE MODEL HIERARCHY                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Level 0: Naive Baselines (Always start here!)                  │
│  ─────────────────────────────────────────────                  │
│  • Classification: Predict most frequent class                  │
│  • Regression: Predict mean/median                              │
│  • Time series: Predict last value (naive forecast)             │
│  • Ranking: Random ordering                                     │
│                                                                 │
│  Level 1: Simple ML Baselines                                   │
│  ────────────────────────────                                   │
│  • Classification: Logistic Regression                          │
│  • Regression: Linear Regression, Ridge                         │
│  • Time series: ARIMA, Exponential Smoothing                    │
│  • Ranking: Pointwise regression                                │
│                                                                 │
│  Level 2: Moderate Complexity                                   │
│  ──────────────────────────                                     │
│  • Random Forest                                                │
│  • Gradient Boosting (XGBoost, LightGBM)                        │
│  • SVM with RBF kernel                                          │
│                                                                 │
│  Level 3: High Complexity (only if needed)                      │
│  ─────────────────────────────────────────                      │
│  • Neural Networks                                              │
│  • Deep Learning architectures                                  │
│  • Ensemble of ensembles                                        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Baseline Implementation

python
import numpy as np
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.metrics import mean_squared_error

# Classification baseline
baseline_clf = DummyClassifier(strategy='most_frequent')
baseline_clf.fit(X_train, y_train)
baseline_accuracy = baseline_clf.score(X_test, y_test)
print(f"Baseline accuracy: {baseline_accuracy:.3f}")

# Regression baseline
baseline_reg = DummyRegressor(strategy='mean')
baseline_reg.fit(X_train, y_train)
baseline_rmse = np.sqrt(mean_squared_error(y_test, baseline_reg.predict(X_test)))
print(f"Baseline RMSE: {baseline_rmse:.3f}")

# Your model must beat these baselines!

Model Selection Framework

Decision Tree for Model Selection

┌─────────────────────────────────────────────────────────────────┐
│              MODEL SELECTION DECISION TREE                      │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│                    What's your data type?                       │
│                          │                                      │
│         ┌────────────────┼────────────────┐                     │
│         │                │                │                     │
│      TABULAR          IMAGE/TEXT       TIME SERIES              │
│         │                │                │                     │
│         ▼                ▼                ▼                     │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐              │
│  │ Gradient    │  │ Deep        │  │ Statistical │              │
│  │ Boosting    │  │ Learning    │  │ + ML        │              │
│  │ (XGBoost,   │  │ (CNN, BERT, │  │ (ARIMA,     │              │
│  │ LightGBM)   │  │ Transformers│  │ Prophet,    │              │
│  │             │  │             │  │ LSTM)       │              │
│  └─────────────┘  └─────────────┘  └─────────────┘              │
│                                                                 │
│  Need interpretability?                                         │
│         │                                                       │
│    ┌────┴────┐                                                  │
│   YES        NO                                                 │
│    │          │                                                 │
│    ▼          ▼                                                 │
│  Linear    Ensemble/                                            │
│  Models    Neural Nets                                          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Model Comparison Matrix

| Model | Interpretability | Training Speed | Inference Speed | Data Size Needed |
| --- | --- | --- | --- | --- |
| Logistic Regression | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | Small |
| Decision Tree | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | Small |
| Random Forest | ⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | Medium |
| XGBoost/LightGBM | ⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐ | Medium |
| SVM | ⭐⭐ | ⭐⭐ | ⭐⭐⭐ | Medium |
| Neural Network | ⭐ | ⭐ | ⭐ | Large |
| Deep Learning | ⭐ | ⭐ | ⭐ | Very Large |

Classical ML vs Deep Learning

When to Use Classical ML

📊 Classical ML Sweet Spot

Classical ML (tree-based and linear models) tends to win when:

| Condition | Why Classical ML |
| --- | --- |
| Tabular data | Trees handle mixed types well |
| < 100K samples | Not enough data for deep learning |
| Need interpretability | Feature importance, coefficients |
| Limited compute | Faster training, no GPU needed |
| Quick iteration | Faster experimentation |

When to Use Deep Learning

🧠 Deep Learning Sweet Spot

Deep learning tends to win when:

| Condition | Why Deep Learning |
| --- | --- |
| Image/Video data | CNNs excel at spatial patterns |
| Text/NLP | Transformers understand language |
| Audio/Speech | Sequential patterns |
| > 1M samples | Enough data to learn |
| Complex patterns | Non-linear relationships |
| Transfer learning | Pre-trained models available |

The Boundary Decision

┌─────────────────────────────────────────────────────────────────┐
│           CLASSICAL ML vs DEEP LEARNING BOUNDARY                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Performance                                                    │
│      ▲                                                          │
│      │                          ╱ Deep Learning                 │
│      │                        ╱                                 │
│      │                      ╱                                   │
│      │                    ╱                                     │
│      │        ──────────╱                                       │
│      │      ╱          │                                        │
│      │    ╱            │  Classical ML                          │
│      │  ╱              │  (plateaus earlier)                    │
│      │╱                │                                        │
│      └─────────────────┴────────────────────────────────────►   │
│                    ~100K              Data Size                 │
│                                                                 │
│  Key Insight:                                                   │
│  • Small data: Classical ML often wins                          │
│  • Large data: Deep Learning scales better                      │
│  • Tabular data: Classical ML usually wins regardless of size   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
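The crossover point in the diagram can be probed empirically with a learning curve: train the candidate model on growing subsets and watch where the validation score plateaus. A minimal sketch using scikit-learn's `learning_curve` on synthetic data (the dataset and model here are illustrative, not a benchmark):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import learning_curve

# Synthetic tabular data stands in for a real dataset
X, y = make_classification(n_samples=2000, n_features=20, random_state=42)

# Validation score at increasing training-set sizes
sizes, train_scores, val_scores = learning_curve(
    LogisticRegression(max_iter=1000), X, y,
    train_sizes=np.linspace(0.1, 1.0, 5), cv=5, scoring='roc_auc',
)

for n, score in zip(sizes, val_scores.mean(axis=1)):
    print(f"n={n}: val AUC = {score:.3f}")
# If the curve has flattened, more data (or a bigger model) adds little;
# if it is still rising, a higher-capacity model may pay off.
```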

Model Selection by Problem Type

Classification

python
# Recommended order for classification
from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier
from xgboost import XGBClassifier

models_to_try = [
    # Level 1: Baselines
    ('Logistic Regression', LogisticRegression(max_iter=1000)),

    # Level 2: Tree-based
    ('Random Forest', RandomForestClassifier(n_estimators=100)),
    ('XGBoost', XGBClassifier(n_estimators=100)),
    ('LightGBM', LGBMClassifier(n_estimators=100)),

    # Level 3: If needed
    ('Neural Network', MLPClassifier(hidden_layer_sizes=(100, 50))),
]

# Quick comparison
for name, model in models_to_try:
    scores = cross_val_score(model, X, y, cv=5, scoring='roc_auc')
    print(f"{name}: AUC = {scores.mean():.3f} ± {scores.std():.3f}")

Regression

| Problem Characteristic | Recommended Model |
| --- | --- |
| Linear relationship | Linear/Ridge Regression |
| Non-linear, tabular | XGBoost, LightGBM |
| Many outliers | Huber Regression, Quantile Regression |
| Uncertainty needed | Bayesian methods, Quantile Regression |
| Time series | ARIMA, Prophet, LSTM |
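To illustrate the outlier row above, Huber regression can be compared with ordinary Ridge on data with a few corrupted targets. A minimal sketch on synthetic data (the corruption rate and model settings are illustrative):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import HuberRegressor, Ridge
from sklearn.model_selection import cross_val_score

# Synthetic linear data, then corrupt 2% of targets with extreme outliers
X, y = make_regression(n_samples=500, n_features=5, noise=10.0, random_state=0)
rng = np.random.default_rng(0)
outlier_idx = rng.choice(len(y), size=10, replace=False)
y[outlier_idx] += 1000.0

results = {}
for name, model in [('Ridge', Ridge()), ('Huber', HuberRegressor(max_iter=500))]:
    scores = cross_val_score(model, X, y, cv=5, scoring='neg_mean_absolute_error')
    results[name] = -scores.mean()
    print(f"{name}: MAE = {results[name]:.1f}")
# Huber's loss caps the influence of corrupted targets, so it typically
# fits the uncorrupted bulk of the data better than squared-loss Ridge.
```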

Ranking/Recommendation

| Approach | Model | Use Case |
| --- | --- | --- |
| Pointwise | Regression on relevance | Simple ranking |
| Pairwise | RankNet, LambdaRank | Search ranking |
| Listwise | ListNet, LambdaMART | Full list optimization |
| Collaborative | Matrix Factorization, ALS | User-item recommendations |
| Content-based | Similarity models | Item features available |
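The pointwise approach in the table reduces ranking to regression: predict a relevance score per item, then sort. A minimal sketch with a gradient-boosted regressor (features and relevance labels are synthetic, purely for illustration):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(0)
# 200 items with 4 features; relevance is a noisy function of the features
X = rng.normal(size=(200, 4))
relevance = X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=200)

# Pointwise: fit relevance directly, ignoring any list structure
model = GradientBoostingRegressor(random_state=0).fit(X, relevance)

# Rank unseen candidate items by predicted relevance, best first
candidates = rng.normal(size=(10, 4))
ranking = np.argsort(-model.predict(candidates))
print("Ranked item indices:", ranking.tolist())
```

Pairwise and listwise methods instead optimize the order itself, which matters when only the top of the list is shown to users.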

Hyperparameter Tuning

Tuning Strategy

┌─────────────────────────────────────────────────────────────────┐
│              HYPERPARAMETER TUNING STRATEGY                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. START WITH DEFAULTS                                         │
│     Most libraries have good defaults                           │
│     Baseline performance without tuning                         │
│                                                                 │
│  2. COARSE SEARCH                                               │
│     Random search over wide ranges                              │
│     Identify promising regions                                  │
│                                                                 │
│  3. FINE SEARCH                                                 │
│     Grid search or Bayesian optimization                        │
│     Narrow down to optimal values                               │
│                                                                 │
│  4. VALIDATE                                                    │
│     Cross-validation on final hyperparameters                   │
│     Check for overfitting to validation set                     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
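Step 2 of the strategy above can be done with scikit-learn's `RandomizedSearchCV`. A sketch over a random forest on synthetic data (the parameter ranges and budget are illustrative):

```python
from scipy.stats import randint
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

X, y = make_classification(n_samples=500, n_features=20, random_state=0)

# Coarse search: wide ranges, small budget, to find promising regions
search = RandomizedSearchCV(
    RandomForestClassifier(random_state=0),
    param_distributions={
        'n_estimators': randint(50, 500),
        'max_depth': randint(3, 30),
    },
    n_iter=10, cv=3, scoring='roc_auc', random_state=0,
)
search.fit(X, y)
print("Best params:", search.best_params_)
print(f"Best CV AUC: {search.best_score_:.3f}")
```

The promising region found here would then be refined in step 3 with a grid search or Bayesian optimization, as in the Optuna example below.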

Key Hyperparameters by Model

| Model | Key Hyperparameters | Typical Range |
| --- | --- | --- |
| Logistic Regression | C (regularization) | 0.001 - 100 |
| Random Forest | n_estimators, max_depth | 100-500, 5-30 |
| XGBoost | learning_rate, max_depth, n_estimators | 0.01-0.3, 3-10, 100-1000 |
| LightGBM | num_leaves, learning_rate | 20-100, 0.01-0.3 |
| Neural Network | learning_rate, layers, dropout | 1e-4 - 1e-2, architecture-dependent |

Optuna Example

python
import optuna
from sklearn.model_selection import cross_val_score
from xgboost import XGBClassifier

def objective(trial):
    params = {
        'n_estimators': trial.suggest_int('n_estimators', 100, 1000),
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
    }
    
    model = XGBClassifier(**params, eval_metric='logloss')
    scores = cross_val_score(model, X_train, y_train, cv=5, scoring='roc_auc')
    return scores.mean()

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

print(f"Best AUC: {study.best_value:.3f}")
print(f"Best params: {study.best_params}")

Ensemble Methods

Ensemble Strategies

┌─────────────────────────────────────────────────────────────────┐
│              ENSEMBLE STRATEGIES                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  1. BAGGING (Bootstrap Aggregating)                             │
│     ─────────────────────────────                               │
│     • Train multiple models on bootstrap samples                │
│     • Average predictions (regression) or vote (classification) │
│     • Example: Random Forest                                    │
│     • Reduces variance                                          │
│                                                                 │
│  2. BOOSTING                                                    │
│     ────────                                                    │
│     • Train models sequentially                                 │
│     • Each model corrects previous errors                       │
│     • Example: XGBoost, LightGBM, AdaBoost                      │
│     • Reduces bias                                              │
│                                                                 │
│  3. STACKING                                                    │
│     ────────                                                    │
│     • Train diverse base models                                 │
│     • Meta-model learns to combine predictions                  │
│     • Most powerful but complex                                 │
│     • Risk of overfitting                                       │
│                                                                 │
│  4. BLENDING                                                    │
│     ────────                                                    │
│     • Simple weighted average of predictions                    │
│     • Weights tuned on validation set                           │
│     • Easy to implement                                         │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Stacking Example

python
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier

# Base models (diverse is better)
base_models = [
    ('lr', LogisticRegression(max_iter=1000)),
    ('rf', RandomForestClassifier(n_estimators=100)),
    ('xgb', XGBClassifier(n_estimators=100)),
]

# Meta-model
stacking_clf = StackingClassifier(
    estimators=base_models,
    final_estimator=LogisticRegression(),
    cv=5
)

stacking_clf.fit(X_train, y_train)
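Blending, by contrast, skips the meta-model: predictions are combined with fixed weights chosen on a validation split. A minimal sketch on synthetic data (the 0.4/0.6 weights are illustrative, not tuned):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

# Train diverse base models on the training split
lr = LogisticRegression(max_iter=1000).fit(X_train, y_train)
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_train, y_train)

# Blend: weighted average of predicted probabilities
# (in practice the weights would be tuned on the validation split)
p_blend = 0.4 * lr.predict_proba(X_val)[:, 1] + 0.6 * rf.predict_proba(X_val)[:, 1]
auc = roc_auc_score(y_val, p_blend)
print(f"Blended AUC: {auc:.3f}")
```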

Best Practices

Model Selection Checklist

| Step | Action | Why |
| --- | --- | --- |
| 1 | Establish baseline | Know minimum acceptable performance |
| 2 | Start simple | Logistic/Linear regression first |
| 3 | Try tree-based | XGBoost/LightGBM for tabular |
| 4 | Consider constraints | Latency, interpretability, compute |
| 5 | Tune hyperparameters | Only after model selection |
| 6 | Validate properly | Cross-validation, holdout test |
| 7 | Document decision | Why this model? |

Common Mistakes

🚨 Model Selection Anti-Patterns

  • Jumping to deep learning: Without trying simpler models
  • Ignoring baselines: No benchmark for comparison
  • Over-tuning: Spending weeks on 0.1% improvement
  • Wrong metric: Optimizing for accuracy on imbalanced data
  • Leaky validation: Using test set for model selection

📎 Cross-References