⚡ Inference Optimization
Level: Ops · Solves: Optimize model inference for production with low latency and high throughput
The Inference Challenge
💡 Professor Tom
Training and inference are two different worlds: training needs high throughput, inference needs low latency. A well-trained model can be completely unusable in production if it is not optimized for inference.
┌─────────────────────────────────────────────────────────────────┐
│ TRAINING vs INFERENCE │
├─────────────────────────────────────────────────────────────────┤
│ │
│ TRAINING INFERENCE │
│ ──────── ───────── │
│ • Throughput-focused • Latency-focused │
│ • Large batches • Often batch=1 │
│ • GPU clusters • Edge devices, CPUs │
│ • FP32/FP16 • INT8/INT4 │
│ • Memory abundant • Memory constrained │
│ • Offline • Real-time │
│ │
│ OPTIMIZATION GOALS │
│ ────────────────── │
│ • Reduce latency (p50, p99) │
│ • Increase throughput (QPS) │
│ • Reduce memory footprint │
│ • Reduce cost per inference │
│ • Maintain accuracy │
│ │
└─────────────────────────────────────────────────────────────────┘
Optimization Techniques Overview
┌─────────────────────────────────────────────────────────────────┐
│ OPTIMIZATION TECHNIQUES │
├─────────────────────────────────────────────────────────────────┤
│ │
│ TECHNIQUE SPEEDUP ACCURACY LOSS COMPLEXITY │
│ ───────── ─────── ───────────── ────────── │
│ Graph Optimization 1.2-2x None Low │
│ Quantization (INT8) 2-4x 0.1-1% Medium │
│ Pruning 1.5-3x 0.5-2% Medium │
│ Knowledge Distill. 2-10x 1-3% High │
│ Hardware Specific 2-5x None Medium │
│ │
│ COMBINATION STRATEGY │
│ ──────────────────── │
│ 1. Graph optimization (always) │
│ 2. Quantization (most impact) │
│ 3. Pruning (if needed) │
│ 4. Distillation (for extreme compression) │
│ │
└─────────────────────────────────────────────────────────────────┘
Quantization
Quantization Fundamentals
┌─────────────────────────────────────────────────────────────────┐
│ QUANTIZATION BASICS │
├─────────────────────────────────────────────────────────────────┤
│ │
│ WHAT IS QUANTIZATION? │
│ ───────────────────── │
│ Converting weights/activations from FP32 to lower precision │
│ │
│ FP32 (32 bits) → FP16 (16 bits) → INT8 (8 bits) → INT4 (4 bits)│
│ │
│ PRECISION COMPARISON │
│ ──────────────────── │
│ FP32: 1 sign + 8 exp + 23 mantissa = 32 bits │
│ FP16: 1 sign + 5 exp + 10 mantissa = 16 bits │
│ INT8: 8 bits, range [-128, 127] or [0, 255] │
│ INT4: 4 bits, range [-8, 7] or [0, 15] │
│ │
│ MEMORY SAVINGS │
│ ────────────── │
│ FP32 → FP16: 2x smaller │
│ FP32 → INT8: 4x smaller │
│ FP32 → INT4: 8x smaller │
│ │
└─────────────────────────────────────────────────────────────────┘
Quantization Types
| Type | When Applied | Accuracy | Speed | Use Case |
|---|---|---|---|---|
| Post-Training (PTQ) | After training | Good | Fast | Quick deployment |
| Quantization-Aware (QAT) | During training | Best | Slow | Production models |
| Dynamic Quantization | At runtime | Good | Medium | Variable inputs |
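The FP32 → INT8 mapping behind all three variants is usually an affine transform: `q = round(x / scale) + zero_point`, with `scale` and `zero_point` derived from the observed tensor range. A minimal sketch of the round trip (pure Python; function names are illustrative, not a library API):

```python
def quantize_params(x_min, x_max, qmin=-128, qmax=127):
    """Compute affine quantization parameters from an observed range."""
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = round(qmin - x_min / scale)
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))  # clamp into the INT8 range

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

scale, zp = quantize_params(-1.0, 1.0)  # calibration found range [-1, 1]
x = 0.4217
q = quantize(x, scale, zp)
x_hat = dequantize(q, scale, zp)
# |x - x_hat| is bounded by scale / 2 (one quantization step)
```

The round-trip error is at most half a quantization step, which is why calibrating on a representative data range matters: a too-wide range inflates `scale` and with it the error.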
Post-Training Quantization
```python
# PyTorch Post-Training Quantization (eager mode)
import torch
import torch.quantization as quant

# Prepare model
model.eval()
model.qconfig = quant.get_default_qconfig('fbgemm')    # For x86
# model.qconfig = quant.get_default_qconfig('qnnpack')  # For ARM

# Fuse layers (Conv+BN+ReLU) -- module names are model-specific
model_fused = quant.fuse_modules(model, [['conv', 'bn', 'relu']])

# Insert observers for calibration
model_prepared = quant.prepare(model_fused)

# Calibrate with representative data
with torch.no_grad():
    for batch in calibration_loader:
        model_prepared(batch)

# Convert to quantized model
model_quantized = quant.convert(model_prepared)
```
Quantization-Aware Training
```python
# QAT for better accuracy
model.train()
model.qconfig = quant.get_default_qat_qconfig('fbgemm')

# Prepare for QAT (inserts fake-quantization modules)
model_qat = quant.prepare_qat(model)

# Train with fake quantization
for epoch in range(num_epochs):
    for batch, target in train_loader:
        optimizer.zero_grad()
        output = model_qat(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Convert to quantized
model_qat.eval()
model_quantized = quant.convert(model_qat)
```
Pruning
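At its core, magnitude pruning just zeroes the smallest-magnitude weights and keeps a binary mask. A dependency-free sketch (a list of floats stands in for a weight tensor; ties at the threshold may prune slightly more than requested):

```python
def magnitude_prune(weights, amount):
    """Zero out the `amount` fraction of weights with smallest |w|."""
    k = int(len(weights) * amount)            # number of weights to remove
    if k == 0:
        return list(weights), [1] * len(weights)
    threshold = sorted(abs(w) for w in weights)[k - 1]
    mask = [0 if abs(w) <= threshold else 1 for w in weights]
    pruned = [w * m for w, m in zip(weights, mask)]
    return pruned, mask

weights = [0.9, -0.05, 0.4, 0.01, -0.7, 0.12]
pruned, mask = magnitude_prune(weights, amount=0.5)
sparsity = mask.count(0) / len(mask)  # fraction of zeroed weights
```

This is the unstructured variant; the structured variant applies the same idea to whole channels (e.g., rank channels by L1 norm and drop the weakest), which is what standard hardware can actually exploit.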
Pruning Strategies
┌─────────────────────────────────────────────────────────────────┐
│ PRUNING STRATEGIES │
├─────────────────────────────────────────────────────────────────┤
│ │
│ UNSTRUCTURED PRUNING │
│ ──────────────────── │
│ • Remove individual weights │
│ • High sparsity possible (90%+) │
│ • Requires sparse hardware/libraries │
│ • Example: Remove weights with |w| < threshold │
│ │
│ STRUCTURED PRUNING │
│ ────────────────── │
│ • Remove entire neurons/channels/layers │
│ • Lower sparsity (30-70%) │
│ • Works with standard hardware │
│ • Example: Remove channels with low L1 norm │
│ │
│ WHEN TO USE WHAT │
│ ──────────────── │
│ Unstructured: Specialized hardware (sparse accelerators) │
│ Structured: Standard GPUs/CPUs │
│ │
└─────────────────────────────────────────────────────────────────┘
Pruning Implementation
```python
import torch.nn.utils.prune as prune

# Unstructured pruning (L1 magnitude): zero out 30% of weights
prune.l1_unstructured(model.conv1, name='weight', amount=0.3)

# Structured pruning: remove 30% of output channels by L2 norm
prune.ln_structured(model.conv1, name='weight', amount=0.3, n=2, dim=0)

# Global pruning (across all layers)
parameters_to_prune = [
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,
)

# Make pruning permanent (removes the mask reparametrization)
prune.remove(model.conv1, 'weight')
```
Graph Optimization
Common Optimizations
┌─────────────────────────────────────────────────────────────────┐
│ GRAPH OPTIMIZATIONS │
├─────────────────────────────────────────────────────────────────┤
│ │
│ OPERATOR FUSION │
│ ─────────────── │
│ Conv → BN → ReLU → FusedConvBNReLU │
│ Reduces memory bandwidth, kernel launch overhead │
│ │
│ CONSTANT FOLDING │
│ ──────────────── │
│ Pre-compute operations on constants at compile time │
│ Example: Fold BN parameters into Conv weights │
│ │
│ DEAD CODE ELIMINATION │
│ ───────────────────── │
│ Remove unused operations and tensors │
│ │
│ LAYOUT OPTIMIZATION │
│ ─────────────────── │
│ NCHW vs NHWC - choose optimal for hardware │
│ Reduce transpose operations │
│ │
└─────────────────────────────────────────────────────────────────┘
TorchScript Optimization
```python
# Convert to TorchScript for optimization
model.eval()

# Tracing (for models without data-dependent control flow)
traced_model = torch.jit.trace(model, example_input)

# Scripting (for models with control flow)
scripted_model = torch.jit.script(model)

# Apply inference-specific optimizations (op fusion, weight freezing)
optimized_model = torch.jit.optimize_for_inference(traced_model)

# Save for deployment
optimized_model.save("model_optimized.pt")
```
Hardware-Specific Optimization
ONNX Runtime
```python
import onnxruntime as ort

# Export to ONNX
torch.onnx.export(
    model,
    example_input,
    "model.onnx",
    opset_version=13,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}}
)

# Create optimized session
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession(
    "model.onnx",
    sess_options,
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)

# Run inference
output = session.run(None, {'input': input_data})
```
TensorRT (NVIDIA)
┌─────────────────────────────────────────────────────────────────┐
│ TENSORRT OPTIMIZATION │
├─────────────────────────────────────────────────────────────────┤
│ │
│ WHAT TENSORRT DOES │
│ ────────────────── │
│ • Layer fusion │
│ • Kernel auto-tuning │
│ • Precision calibration (FP16/INT8) │
│ • Memory optimization │
│ │
│ TYPICAL SPEEDUP │
│ ─────────────── │
│ • FP32: 2-3x faster than PyTorch │
│ • FP16: 4-6x faster │
│ • INT8: 6-10x faster │
│ │
│ WORKFLOW │
│ ──────── │
│ PyTorch → ONNX → TensorRT Engine → Deploy │
│ │
└─────────────────────────────────────────────────────────────────┘
Batching Strategies
Dynamic Batching
┌─────────────────────────────────────────────────────────────────┐
│ BATCHING STRATEGIES │
├─────────────────────────────────────────────────────────────────┤
│ │
│ NO BATCHING (batch=1) │
│ ───────────────────── │
│ • Lowest latency per request │
│ • Lowest throughput │
│ • GPU underutilized │
│ │
│ STATIC BATCHING │
│ ─────────────── │
│ • Fixed batch size │
│ • Wait for batch to fill │
│ • Higher latency, higher throughput │
│ │
│ DYNAMIC BATCHING │
│ ──────────────── │
│ • Batch requests within time window │
│ • Balance latency vs throughput │
│ • Best for production │
│ │
│ CONTINUOUS BATCHING (LLMs) │
│ ────────────────────────── │
│ • Add/remove requests mid-batch │
│ • Optimal for variable-length sequences │
│ • Used by vLLM, TGI │
│ │
└─────────────────────────────────────────────────────────────────┘
Latency Profiling
Profiling Tools
```python
# PyTorch Profiler
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
) as prof:
    for _ in range(10):
        model(input_data)

# Print summary
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Export for visualization (open in chrome://tracing or Perfetto)
prof.export_chrome_trace("trace.json")
```
Latency Breakdown
┌─────────────────────────────────────────────────────────────────┐
│ LATENCY BREAKDOWN │
├─────────────────────────────────────────────────────────────────┤
│ │
│ TYPICAL INFERENCE LATENCY COMPONENTS │
│ ──────────────────────────────────── │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Data Transfer (CPU→GPU) │████░░░░░░│ 15% │ │
│ │ Preprocessing │██░░░░░░░░│ 10% │ │
│ │ Model Forward Pass │████████░░│ 60% │ │
│ │ Postprocessing │█░░░░░░░░░│ 5% │ │
│ │ Data Transfer (GPU→CPU) │██░░░░░░░░│ 10% │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ OPTIMIZATION PRIORITY │
│ ───────────────────── │
│ 1. Model forward pass (quantization, pruning) │
│ 2. Data transfer (pinned memory, async) │
│ 3. Pre/post processing (GPU acceleration) │
│ │
└─────────────────────────────────────────────────────────────────┘
Production Deployment Checklist
┌─────────────────────────────────────────────────────────────────┐
│ DEPLOYMENT CHECKLIST │
├─────────────────────────────────────────────────────────────────┤
│ │
│ OPTIMIZATION │
│ ──────────── │
│ □ Model converted to inference format (TorchScript/ONNX) │
│ □ Quantization applied and validated │
│ □ Graph optimizations enabled │
│ □ Batch size tuned for target hardware │
│ │
│ VALIDATION │
│ ────────── │
│ □ Accuracy verified post-optimization │
│ □ Latency meets SLA (p50, p99) │
│ □ Memory usage within limits │
│ □ Throughput meets requirements │
│ │
│ MONITORING │
│ ────────── │
│ □ Latency metrics instrumented │
│ □ Error rate tracking │
│ □ Resource utilization monitoring │
│ □ Model drift detection │
│ │
└─────────────────────────────────────────────────────────────────┘
✅ Operational Checklist
Pre-Optimization Gate
| Check | Required | Owner | Notes |
|---|---|---|---|
| Baseline latency measured | ✓ | ML Engineer | p50, p95, p99 |
| Target SLO defined | ✓ | Product + ML Lead | e.g., p99 < 100ms |
| Accuracy threshold set | ✓ | ML Lead | Max acceptable degradation |
| Hardware target identified | ✓ | Platform | GPU/CPU/Edge |
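The baseline percentiles in this gate can be computed directly from raw per-request timings; a minimal sketch using the nearest-rank method (function name is illustrative):

```python
def percentile(samples, p):
    """Nearest-rank percentile: p in (0, 100]."""
    ranked = sorted(samples)
    idx = max(0, -(-len(ranked) * p // 100) - 1)  # ceil(n*p/100) - 1
    return ranked[int(idx)]

latencies_ms = [12, 15, 11, 90, 14, 13, 16, 230, 12, 15]  # per-request timings
p50 = percentile(latencies_ms, 50)
p99 = percentile(latencies_ms, 99)
```

Note how one slow request dominates p99 while leaving p50 untouched, which is why the SLO should be stated in tail percentiles, not averages.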
Optimization Validation Gate
| Check | Required | How to Verify |
|---|---|---|
| Accuracy within threshold | ✓ | Compare to baseline on holdout |
| Latency meets SLO | ✓ | Load test with realistic traffic |
| Memory within limits | ✓ | Profile on target hardware |
| Throughput meets requirements | ✓ | Stress test at peak load |
| Edge cases handled | ✓ | Test with max-length inputs |
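"Accuracy within threshold" can be enforced mechanically by comparing baseline and optimized predictions on the holdout set; a sketch using disagreement rate as a proxy for accuracy loss (names are illustrative):

```python
def accuracy_gate(baseline_preds, optimized_preds, max_degradation=0.01):
    """Pass if the optimized model disagrees with the baseline on at most
    a `max_degradation` fraction of holdout examples."""
    assert len(baseline_preds) == len(optimized_preds)
    disagreements = sum(b != o for b, o in zip(baseline_preds, optimized_preds))
    rate = disagreements / len(baseline_preds)
    return rate <= max_degradation, rate

ok, rate = accuracy_gate([1, 0, 2, 1, 1] * 40, [1, 0, 2, 1, 1] * 40)
# identical predictions -> rate == 0.0, gate passes
```

Disagreement rate upper-bounds the accuracy drop relative to the baseline, so it is a conservative gate that does not even need holdout labels.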
Production Deployment Gate
| Check | Required | Owner |
|---|---|---|
| Model converted to deployment format | ✓ | ML Engineer |
| Quantization tested | ✓ | ML Engineer |
| Batching strategy configured | ✓ | ML Engineer |
| Fallback mechanism ready | ✓ | ML Lead |
| Latency monitoring instrumented | ✓ | Platform |
| SLO alerting configured | ✓ | Platform |
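The "fallback mechanism" row is worth making concrete: serve the optimized model, but fall back to the unoptimized baseline when it fails. A minimal sketch (the two `predict` callables are hypothetical stand-ins for real model endpoints):

```python
def predict_with_fallback(optimized_predict, baseline_predict, x):
    """Try the optimized model first; fall back to the baseline on failure."""
    try:
        return optimized_predict(x), "optimized"
    except Exception:
        # In production: log the error and emit a fallback-rate metric here
        return baseline_predict(x), "baseline"

def broken(x):
    raise RuntimeError("bad engine")  # simulate a failing optimized model

result, path = predict_with_fallback(broken, lambda x: x * 2, 21)
```

Tracking the fallback rate as a metric closes the loop with the monitoring rows above: a rising rate means the optimized path is degrading.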
📎 Cross-References
- 📎 Model Ops - Model registry and versioning
- 📎 ML Deployment - Deployment patterns
- 📎 ML Monitoring - Production monitoring
- 📎 ML Governance - Production readiness
- 📎 Architectures Overview - Model architectures
- 📎 LLM Cost Optimization - LLM-specific optimization