Inference Optimization

Level: Ops | Solves: Optimize model inference for production with low latency and high throughput

The Inference Challenge

💡 Professor Tom

Training and inference are two different worlds. Training needs high throughput; inference needs low latency. A model that trains well can be completely unusable in production if it is not optimized for inference.

┌─────────────────────────────────────────────────────────────────┐
│              TRAINING vs INFERENCE                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  TRAINING                      INFERENCE                        │
│  ────────                      ─────────                        │
│  • Throughput-focused          • Latency-focused                │
│  • Large batches               • Often batch=1                  │
│  • GPU clusters                • Edge devices, CPUs             │
│  • FP32/FP16                   • INT8/INT4                      │
│  • Memory abundant             • Memory constrained             │
│  • Offline                     • Real-time                      │
│                                                                 │
│  OPTIMIZATION GOALS                                             │
│  ──────────────────                                             │
│  • Reduce latency (p50, p99)                                    │
│  • Increase throughput (QPS)                                    │
│  • Reduce memory footprint                                      │
│  • Reduce cost per inference                                    │
│  • Maintain accuracy                                            │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Optimization Techniques Overview

┌─────────────────────────────────────────────────────────────────┐
│              OPTIMIZATION TECHNIQUES                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  TECHNIQUE          SPEEDUP    ACCURACY LOSS    COMPLEXITY      │
│  ─────────          ───────    ─────────────    ──────────      │
│  Graph Optimization  1.2-2x    None             Low             │
│  Quantization (INT8) 2-4x      0.1-1%           Medium          │
│  Pruning             1.5-3x    0.5-2%           Medium          │
│  Knowledge Distill.  2-10x     1-3%             High            │
│  Hardware Specific   2-5x      None             Medium          │
│                                                                 │
│  COMBINATION STRATEGY                                           │
│  ────────────────────                                           │
│  1. Graph optimization (always)                                 │
│  2. Quantization (most impact)                                  │
│  3. Pruning (if needed)                                         │
│  4. Distillation (for extreme compression)                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Quantization

Quantization Fundamentals

┌─────────────────────────────────────────────────────────────────┐
│              QUANTIZATION BASICS                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  WHAT IS QUANTIZATION?                                          │
│  ─────────────────────                                          │
│  Converting weights/activations from FP32 to lower precision    │
│                                                                 │
│  FP32 (32 bits) → FP16 (16 bits) → INT8 (8 bits) → INT4 (4 bits)│
│                                                                 │
│  PRECISION COMPARISON                                           │
│  ────────────────────                                           │
│  FP32: 1 sign + 8 exp + 23 mantissa = 32 bits                   │
│  FP16: 1 sign + 5 exp + 10 mantissa = 16 bits                   │
│  INT8: 8 bits, range [-128, 127] or [0, 255]                    │
│  INT4: 4 bits, range [-8, 7] or [0, 15]                         │
│                                                                 │
│  MEMORY SAVINGS                                                 │
│  ──────────────                                                 │
│  FP32 → FP16: 2x smaller                                        │
│  FP32 → INT8: 4x smaller                                        │
│  FP32 → INT4: 8x smaller                                        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
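The FP32 → INT8 mapping above is typically an affine transform: q = round(x / scale) + zero_point, clamped to the integer range. A minimal pure-Python sketch (helper names here are illustrative, not a library API):

```python
def compute_qparams(x_min, x_max, q_min=-128, q_max=127):
    """Derive scale and zero-point mapping [x_min, x_max] onto INT8."""
    scale = (x_max - x_min) / (q_max - q_min)
    zero_point = round(q_min - x_min / scale)
    return scale, zero_point

def quantize(x, scale, zero_point, q_min=-128, q_max=127):
    q = round(x / scale) + zero_point
    return max(q_min, min(q_max, q))  # clamp to the INT8 range

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

scale, zp = compute_qparams(-1.0, 1.0)
x = 0.5
x_hat = dequantize(quantize(x, scale, zp), scale, zp)
# round-trip error is bounded by half a quantization step
assert abs(x - x_hat) <= scale / 2
```

Calibration (as in the PTQ code below) is what estimates `x_min`/`x_max` for activations from representative data.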

Quantization Types

| Type | When Applied | Accuracy | Speed | Use Case |
| --- | --- | --- | --- | --- |
| Post-Training (PTQ) | After training | Good | Fast | Quick deployment |
| Quantization-Aware (QAT) | During training | Best | Slow | Production models |
| Dynamic Quantization | At runtime | Good | Medium | Variable inputs |

Post-Training Quantization

python
# PyTorch Post-Training Quantization (eager mode)
# Note: eager-mode quantization also expects QuantStub/DeQuantStub
# at the model's input/output boundaries.
import torch
import torch.quantization as quant

# Prepare model
model.eval()
model.qconfig = quant.get_default_qconfig('fbgemm')  # For x86
# model.qconfig = quant.get_default_qconfig('qnnpack')  # For ARM

# Fuse layers (Conv+BN+ReLU)
model_fused = quant.fuse_modules(model, [['conv', 'bn', 'relu']])

# Prepare for quantization
model_prepared = quant.prepare(model_fused)

# Calibrate with representative data
with torch.no_grad():
    for batch in calibration_loader:
        model_prepared(batch)

# Convert to quantized model
model_quantized = quant.convert(model_prepared)

Quantization-Aware Training

python
# QAT for better accuracy
model.train()
model.qconfig = quant.get_default_qat_qconfig('fbgemm')

# Prepare for QAT
model_qat = quant.prepare_qat(model)

# Train with fake quantization
for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        output = model_qat(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

# Convert to quantized
model_qat.eval()
model_quantized = quant.convert(model_qat)
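The "fake quantization" the QAT loop trains through is a quantize-then-dequantize round trip: values stay in FP32 but carry the rounding error that real INT8 inference will introduce, so the weights learn to tolerate it. A pure-Python illustration (illustrative names, not the PyTorch internals):

```python
def fake_quantize(x, scale, zero_point=0, q_min=-128, q_max=127):
    """Quantize then immediately dequantize: the value remains FP32
    but is snapped to the nearest representable INT8 level."""
    q = max(q_min, min(q_max, round(x / scale) + zero_point))
    return (q - zero_point) * scale

scale = 2 / 255  # range [-1, 1] spread over 256 INT8 levels
w = 0.3337
w_fq = fake_quantize(w, scale)
# the forward pass sees w_fq; in QAT, gradients flow through the
# rounding as if it were the identity (straight-through estimator)
```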

Pruning

Pruning Strategies

┌─────────────────────────────────────────────────────────────────┐
│              PRUNING STRATEGIES                                 │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  UNSTRUCTURED PRUNING                                           │
│  ────────────────────                                           │
│  • Remove individual weights                                    │
│  • High sparsity possible (90%+)                                │
│  • Requires sparse hardware/libraries                           │
│  • Example: Remove weights with |w| < threshold                 │
│                                                                 │
│  STRUCTURED PRUNING                                             │
│  ──────────────────                                             │
│  • Remove entire neurons/channels/layers                        │
│  • Lower sparsity (30-70%)                                      │
│  • Works with standard hardware                                 │
│  • Example: Remove channels with low L1 norm                    │
│                                                                 │
│  WHEN TO USE WHAT                                               │
│  ────────────────                                               │
│  Unstructured: Specialized hardware (sparse accelerators)       │
│  Structured: Standard GPUs/CPUs                                 │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
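The unstructured rule above (zero out weights with |w| < threshold) fits in a few lines of plain Python (illustrative helpers, not the PyTorch pruning API shown next):

```python
def magnitude_prune(weights, threshold):
    """Zero out weights whose magnitude falls below the threshold."""
    return [0.0 if abs(w) < threshold else w for w in weights]

def sparsity(weights):
    """Fraction of weights that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)

w = [0.8, -0.05, 0.02, -0.9, 0.001, 0.4]
pruned = magnitude_prune(w, threshold=0.1)
assert sparsity(pruned) == 0.5  # half the weights removed
```

Note the hardware caveat from the box: this sparsity only yields speedup if the runtime can skip the zeros.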

Pruning Implementation

python
import torch.nn.utils.prune as prune

# Unstructured pruning (L1)
prune.l1_unstructured(model.conv1, name='weight', amount=0.3)

# Structured pruning (remove channels)
prune.ln_structured(model.conv1, name='weight', amount=0.3, n=2, dim=0)

# Global pruning (across all layers)
parameters_to_prune = [
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,
)

# Make pruning permanent
prune.remove(model.conv1, 'weight')

Graph Optimization

Common Optimizations

┌─────────────────────────────────────────────────────────────────┐
│              GRAPH OPTIMIZATIONS                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  OPERATOR FUSION                                                │
│  ───────────────                                                │
│  Conv → BN → ReLU  →  FusedConvBNReLU                           │
│  Reduces memory bandwidth, kernel launch overhead               │
│                                                                 │
│  CONSTANT FOLDING                                               │
│  ────────────────                                               │
│  Pre-compute operations on constants at compile time            │
│  Example: Fold BN parameters into Conv weights                  │
│                                                                 │
│  DEAD CODE ELIMINATION                                          │
│  ─────────────────────                                          │
│  Remove unused operations and tensors                           │
│                                                                 │
│  LAYOUT OPTIMIZATION                                            │
│  ───────────────────                                            │
│  NCHW vs NHWC - choose optimal for hardware                     │
│  Reduce transpose operations                                    │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
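The BN-folding example of constant folding works because BatchNorm at inference time is an affine transform with frozen statistics: w' = w·γ/√(σ²+ε) and b' = (b−μ)·γ/√(σ²+ε) + β. A scalar sketch verifying the identity (illustrative helper name):

```python
import math

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold frozen BatchNorm parameters into the preceding conv's
    weight and bias so BN disappears from the inference graph."""
    k = gamma / math.sqrt(var + eps)
    return w * k, (b - mean) * k + beta

# Conv (y = w*x + b) followed by BN must equal the folded conv alone
w, b = 0.7, 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.05, 0.9
x = 2.0
y_bn = gamma * ((w * x + b) - mean) / math.sqrt(var + 1e-5) + beta
wf, bf = fold_bn_into_conv(w, b, gamma, beta, mean, var)
assert abs((wf * x + bf) - y_bn) < 1e-9
```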

TorchScript Optimization

python
# Convert to TorchScript for optimization
model.eval()

# Tracing (for models without control flow)
traced_model = torch.jit.trace(model, example_input)

# Scripting (for models with control flow)
scripted_model = torch.jit.script(model)

# Optimize for inference
optimized_model = torch.jit.optimize_for_inference(traced_model)

# Save for deployment
optimized_model.save("model_optimized.pt")

Hardware-Specific Optimization

ONNX Runtime

python
import onnxruntime as ort

# Export to ONNX
torch.onnx.export(
    model,
    example_input,
    "model.onnx",
    opset_version=13,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}}
)

# Create optimized session
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

session = ort.InferenceSession(
    "model.onnx",
    sess_options,
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)

# Run inference
output = session.run(None, {'input': input_data})

TensorRT (NVIDIA)

┌─────────────────────────────────────────────────────────────────┐
│              TENSORRT OPTIMIZATION                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  WHAT TENSORRT DOES                                             │
│  ──────────────────                                             │
│  • Layer fusion                                                 │
│  • Kernel auto-tuning                                           │
│  • Precision calibration (FP16/INT8)                            │
│  • Memory optimization                                          │
│                                                                 │
│  TYPICAL SPEEDUP                                                │
│  ───────────────                                                │
│  • FP32: 2-3x faster than PyTorch                               │
│  • FP16: 4-6x faster                                            │
│  • INT8: 6-10x faster                                           │
│                                                                 │
│  WORKFLOW                                                       │
│  ────────                                                       │
│  PyTorch → ONNX → TensorRT Engine → Deploy                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Batching Strategies

Dynamic Batching

┌─────────────────────────────────────────────────────────────────┐
│              BATCHING STRATEGIES                                │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  NO BATCHING (batch=1)                                          │
│  ─────────────────────                                          │
│  • Lowest latency per request                                   │
│  • Lowest throughput                                            │
│  • GPU underutilized                                            │
│                                                                 │
│  STATIC BATCHING                                                │
│  ───────────────                                                │
│  • Fixed batch size                                             │
│  • Wait for batch to fill                                       │
│  • Higher latency, higher throughput                            │
│                                                                 │
│  DYNAMIC BATCHING                                               │
│  ────────────────                                               │
│  • Batch requests within time window                            │
│  • Balance latency vs throughput                                │
│  • Best for production                                          │
│                                                                 │
│  CONTINUOUS BATCHING (LLMs)                                     │
│  ──────────────────────────                                     │
│  • Add/remove requests mid-batch                                │
│  • Optimal for variable-length sequences                        │
│  • Used by vLLM, TGI                                            │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
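Dynamic batching boils down to a queue drained whenever the batch fills or the wait window expires. A simplified, time-free simulation of the queue-draining half (illustrative names; production servers such as Triton implement this natively):

```python
from collections import deque

def drain_batches(requests, max_batch_size):
    """Group queued requests into batches of at most max_batch_size.
    In a real server a timeout also flushes a partial batch, so a
    lone request never waits longer than the window (e.g. a few ms)."""
    queue = deque(requests)
    batches = []
    while queue:
        batch = [queue.popleft() for _ in range(min(max_batch_size, len(queue)))]
        batches.append(batch)
    return batches

batches = drain_batches(list(range(10)), max_batch_size=4)
assert [len(b) for b in batches] == [4, 4, 2]
```

Tuning `max_batch_size` and the wait window is exactly the latency-vs-throughput trade named in the box.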

Latency Profiling

Profiling Tools

python
# PyTorch Profiler
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
) as prof:
    for _ in range(10):
        model(input_data)

# Print summary
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Export for visualization
prof.export_chrome_trace("trace.json")
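Profiler tables are per-operator; SLA reporting also needs end-to-end percentile latencies (p50, p99) over many runs. A stdlib-only sketch (helper names are illustrative):

```python
import statistics
import time

def measure_latencies(fn, n=100):
    """Time n calls of fn and return per-call latency in milliseconds."""
    samples = []
    for _ in range(n):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
    return samples

def percentile(samples, p):
    """p-th percentile read off the 100-quantile cut points."""
    return statistics.quantiles(samples, n=100)[p - 1]

latencies = measure_latencies(lambda: sum(range(10_000)))
p50, p99 = percentile(latencies, 50), percentile(latencies, 99)
assert p50 <= p99  # tail latency is never below the median
```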

Latency Breakdown

┌─────────────────────────────────────────────────────────────────┐
│              LATENCY BREAKDOWN                                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  TYPICAL INFERENCE LATENCY COMPONENTS                           │
│  ────────────────────────────────────                           │
│                                                                 │
│  ┌──────────────────────────────────────────────────────┐       │
│  │ Data Transfer (CPU→GPU)           │████░░░░░░│ 15%   │       │
│  │ Preprocessing                     │██░░░░░░░░│ 10%   │       │
│  │ Model Forward Pass                │████████░░│ 60%   │       │
│  │ Postprocessing                    │█░░░░░░░░░│ 5%    │       │
│  │ Data Transfer (GPU→CPU)           │██░░░░░░░░│ 10%   │       │
│  └──────────────────────────────────────────────────────┘       │
│                                                                 │
│  OPTIMIZATION PRIORITY                                          │
│  ─────────────────────                                          │
│  1. Model forward pass (quantization, pruning)                  │
│  2. Data transfer (pinned memory, async)                        │
│  3. Pre/post processing (GPU acceleration)                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
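The breakdown above also bounds what any single optimization can buy (Amdahl's law): doubling forward-pass speed when it is 60% of latency gives 1 / (0.4 + 0.6/2) ≈ 1.43x end to end, which is why the forward pass tops the priority list. A sketch using the chart's percentages (function name is illustrative):

```python
def end_to_end_speedup(fractions, component, factor):
    """Overall speedup when one latency component is accelerated by
    `factor` while the rest stay unchanged (Amdahl's law)."""
    f = fractions[component]
    new_total = (1 - f) + f / factor
    return 1 / new_total

# percentages from the latency breakdown above
fractions = {"transfer_in": 0.15, "pre": 0.10, "forward": 0.60,
             "post": 0.05, "transfer_out": 0.10}

# doubling forward-pass speed (e.g. via INT8 quantization)
speedup = end_to_end_speedup(fractions, "forward", 2.0)
assert round(speedup, 2) == 1.43
```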

Production Deployment Checklist

┌─────────────────────────────────────────────────────────────────┐
│              DEPLOYMENT CHECKLIST                               │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  OPTIMIZATION                                                   │
│  ────────────                                                   │
│  □ Model converted to inference format (TorchScript/ONNX)       │
│  □ Quantization applied and validated                           │
│  □ Graph optimizations enabled                                  │
│  □ Batch size tuned for target hardware                         │
│                                                                 │
│  VALIDATION                                                     │
│  ──────────                                                     │
│  □ Accuracy verified post-optimization                          │
│  □ Latency meets SLA (p50, p99)                                 │
│  □ Memory usage within limits                                   │
│  □ Throughput meets requirements                                │
│                                                                 │
│  MONITORING                                                     │
│  ──────────                                                     │
│  □ Latency metrics instrumented                                 │
│  □ Error rate tracking                                          │
│  □ Resource utilization monitoring                              │
│  □ Model drift detection                                        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Operational Checklist

Pre-Optimization Gate

| Check | Required | Owner | Notes |
| --- | --- | --- | --- |
| Baseline latency measured | ✓ | ML Engineer | p50, p95, p99 |
| Target SLO defined | ✓ | Product + ML Lead | e.g., p99 < 100ms |
| Accuracy threshold set | ✓ | ML Lead | Max acceptable degradation |
| Hardware target identified | ✓ | Platform | GPU/CPU/Edge |

Optimization Validation Gate

| Check | Required | How to Verify |
| --- | --- | --- |
| Accuracy within threshold | ✓ | Compare to baseline on holdout |
| Latency meets SLO | ✓ | Load test with realistic traffic |
| Memory within limits | ✓ | Profile on target hardware |
| Throughput meets requirements | ✓ | Stress test at peak load |
| Edge cases handled | ✓ | Test with max-length inputs |

Production Deployment Gate

| Check | Required | Owner |
| --- | --- | --- |
| Model converted to deployment format | ✓ | ML Engineer |
| Quantization tested | ✓ | ML Engineer |
| Batching strategy configured | ✓ | ML Engineer |
| Fallback mechanism ready | ✓ | ML Lead |
| Latency monitoring instrumented | ✓ | Platform |
| SLO alerting configured | ✓ | Platform |

📎 Cross-References