Native AI & Machine Learning Cutting Edge

Programming AI entirely in Rust: no Python, no GIL, no overhead

Why Rust for AI?

Python dominates AI training, but inference (running already-trained models) is shifting to Rust:

┌─────────────────────────────────────────────────────────────────────┐
│                 PYTHON vs RUST FOR LLM INFERENCE                    │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  Metric              │  Python + PyTorch  │  Rust (Candle/Burn)    │
│  ────────────────────┼────────────────────┼────────────────────    │
│  Cold Start          │  5-15 seconds      │  0.5-2 seconds         │
│  Memory Overhead     │  500MB-2GB         │  50-200MB              │
│  Token/sec (LLM)     │  Baseline          │  2-5x faster           │
│  Deployment Size     │  1-5GB             │  50-200MB              │
│  GIL (concurrency)   │  Bottleneck        │  N/A (no GIL)          │
│                                                                     │
│  Use Cases:                                                         │
│  Python: Training, Research, Prototyping                            │
│  Rust: Production Inference, Edge, WASM, Embedded                   │
└─────────────────────────────────────────────────────────────────────┘
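The "no GIL" row deserves a concrete illustration: in Rust, CPU-bound work scales across OS threads with nothing but the standard library. A minimal sketch (no ML crates; `heavy_compute` is a stand-in for an inference step):

```rust
use std::thread;

// Stand-in for a CPU-bound inference step on one shard of the workload.
fn heavy_compute(shard: &[u64]) -> u64 {
    shard.iter().map(|x| x * x % 1_000_003).sum()
}

fn main() {
    let data: Vec<u64> = (0..1_000).collect();
    let n_threads = 4;
    let chunk = data.len() / n_threads;

    // Scoped threads run truly in parallel -- there is no GIL to serialize them.
    let total: u64 = thread::scope(|s| {
        let handles: Vec<_> = data
            .chunks(chunk)
            .map(|shard| s.spawn(move || heavy_compute(shard)))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).sum()
    });

    // Same result as the sequential version, computed on 4 cores.
    assert_eq!(total, heavy_compute(&data));
    println!("total = {total}");
}
```

In CPython the equivalent `threading` code would still execute the compute on one core at a time; here each thread makes real progress.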

1. Candle (by HuggingFace)

Candle is a minimalist ML framework from HuggingFace, the team behind the Transformers library.

Key Features

  • Pure Rust (no Python bindings)
  • CPU + CUDA + Metal acceleration
  • Models: BERT, Llama 2, Mistral, Whisper, Stable Diffusion
  • Designed for production inference

Setup

toml
# Cargo.toml
[dependencies]
candle-core = "0.4"
candle-nn = "0.4"
candle-transformers = "0.4"
hf-hub = "0.3"  # HuggingFace model downloads
tokenizers = "0.15"
anyhow = "1"
serde_json = "1"  # parse config.json

Example: BERT Embeddings

rust
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
use anyhow::Result;

fn main() -> Result<()> {
    // 1. Load the tokenizer and model weights from the HuggingFace Hub
    let api = Api::new()?;
    let repo = api.repo(Repo::new(
        "sentence-transformers/all-MiniLM-L6-v2".to_string(),
        RepoType::Model,
    ));
    
    let tokenizer_path = repo.get("tokenizer.json")?;
    let weights_path = repo.get("model.safetensors")?;
    let config_path = repo.get("config.json")?;
    
    // 2. Initialize
    let device = Device::Cpu;  // or Device::cuda_if_available(0)?
    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_path)?)?;
    
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)?
    };
    let model = BertModel::load(vb, &config)?;
    
    // 3. Tokenize input
    let text = "Rust is the future of systems programming";
    let encoding = tokenizer.encode(text, true).map_err(anyhow::Error::msg)?;
    let input_ids = Tensor::new(encoding.get_ids(), &device)?.unsqueeze(0)?;
    let token_type_ids = Tensor::zeros_like(&input_ids)?;
    
    // 4. Get embeddings
    let embeddings = model.forward(&input_ids, &token_type_ids)?;
    
    // Mean pooling
    let (_, seq_len, _hidden_size) = embeddings.dims3()?;
    let pooled = (embeddings.sum(1)? / (seq_len as f64))?;
    
    println!("Embedding shape: {:?}", pooled.dims());
    println!("First 5 values: {:?}", &pooled.flatten_all()?.to_vec1::<f32>()?[..5]);
    
    Ok(())
}
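Sentence embeddings like these are usually compared with cosine similarity (e.g. for semantic search). A dependency-free helper for the pooled vectors; the function name is ours:

```rust
// Cosine similarity between two embedding vectors of equal length.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}

fn main() {
    let a = [1.0_f32, 0.0, 1.0];
    let b = [1.0_f32, 0.0, 1.0];
    let c = [0.0_f32, 1.0, 0.0];

    // Identical vectors score 1.0; orthogonal ones score 0.0.
    assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
    assert!(cosine_similarity(&a, &c).abs() < 1e-6);
    println!("sim(a, b) = {}", cosine_similarity(&a, &b));
}
```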

LLM Inference (Llama 2 / Mistral)

rust
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::llama::{Cache, Llama};
use tokenizers::Tokenizer;
use anyhow::Result;

// The exact `forward` signature varies slightly between candle versions;
// this follows the 0.4-era API, where the KV cache is passed in explicitly.
fn generate_text(
    model: &Llama,
    cache: &mut Cache,     // KV cache, built during model setup
    tokenizer: &Tokenizer,
    device: &Device,
    eos_token_id: u32,
    prompt: &str,
    max_tokens: usize,
) -> Result<String> {
    // seed, temperature, top-p
    let mut logits_processor = LogitsProcessor::new(42, Some(0.8), Some(0.95));
    let mut tokens = tokenizer
        .encode(prompt, true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();

    let mut index_pos = 0;
    for _ in 0..max_tokens {
        // Thanks to the KV cache, only not-yet-processed tokens are fed in:
        // the full prompt on the first step, a single token afterwards.
        let input = Tensor::new(&tokens[index_pos..], device)?.unsqueeze(0)?;
        let logits = model.forward(&input, index_pos, cache)?.squeeze(0)?;
        index_pos = tokens.len();

        let next_token = logits_processor.sample(&logits)?;
        tokens.push(next_token);
        if next_token == eos_token_id {
            break;
        }
    }

    tokenizer.decode(&tokens, true).map_err(anyhow::Error::msg)
}
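What `LogitsProcessor::new(42, Some(0.8), Some(0.95))` configures is temperature scaling plus top-p (nucleus) sampling. A dependency-free sketch of those two steps (function names are ours, and the final random draw is omitted):

```rust
// Softmax with temperature: lower T sharpens the distribution.
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| ((l - max) / temperature).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

// Top-p (nucleus) filtering: keep the smallest set of tokens whose
// cumulative probability reaches `p`, then renormalize. Sampling then
// happens only within this nucleus.
fn top_p_filter(probs: &[f32], p: f32) -> Vec<(usize, f32)> {
    let mut indexed: Vec<(usize, f32)> = probs.iter().cloned().enumerate().collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let mut cumulative = 0.0;
    let mut kept = Vec::new();
    for (i, prob) in indexed {
        kept.push((i, prob));
        cumulative += prob;
        if cumulative >= p {
            break;
        }
    }
    let total: f32 = kept.iter().map(|(_, pr)| pr).sum();
    kept.into_iter().map(|(i, pr)| (i, pr / total)).collect()
}

fn main() {
    let logits = [2.0_f32, 1.0, 0.1, -1.0];
    let probs = softmax_with_temperature(&logits, 0.8);
    let nucleus = top_p_filter(&probs, 0.95);
    // The unlikely tail token is dropped before sampling.
    assert!(nucleus.len() < logits.len());
    println!("{nucleus:?}");
}
```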

2. Burn Framework

Burn is a modern deep learning framework with dynamic computation graphs (like PyTorch).

Key Features

  • Multi-backend: WGPU (GPU), NdArray (CPU), Candle, LibTorch
  • #[derive(Module)] macro for custom modules
  • Automatic differentiation
  • Training + inference

Setup

toml
[dependencies]
burn = "0.13"
burn-ndarray = "0.13"  # CPU backend
# burn-wgpu = "0.13"   # GPU backend

Example: Simple Neural Network

rust
use burn::prelude::*;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig, Relu};
use burn::tensor::backend::Backend;

// Define model architecture
#[derive(Module, Debug)]
pub struct MLP<B: Backend> {
    linear1: Linear<B>,
    linear2: Linear<B>,
    activation: Relu,
}

impl<B: Backend> MLP<B> {
    pub fn new(device: &B::Device) -> Self {
        let linear1 = LinearConfig::new(784, 256).init(device);
        let linear2 = LinearConfig::new(256, 10).init(device);
        
        Self {
            linear1,
            linear2,
            activation: Relu::new(),
        }
    }
    
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.linear1.forward(x);
        let x = self.activation.forward(x);
        self.linear2.forward(x)
    }
}

fn main() {
    type MyBackend = burn_ndarray::NdArray<f32>;
    let device = Default::default();
    
    let model = MLP::<MyBackend>::new(&device);
    
    // Random input (batch_size=32, features=784)
    let input = Tensor::<MyBackend, 2>::random(
        [32, 784],
        burn::tensor::Distribution::Normal(0.0, 1.0),
        &device,
    );
    
    let output = model.forward(input);
    println!("Output shape: {:?}", output.dims());  // [32, 10]
}
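The [32, 10] output holds raw logits; turning them into class predictions is an argmax over the last dimension (Burn tensors have an argmax method for this). The same idea on a plain slice, to make the shape handling explicit:

```rust
// Argmax per row of a flattened (batch, classes) matrix.
fn argmax_rows(logits: &[f32], num_classes: usize) -> Vec<usize> {
    logits
        .chunks(num_classes)
        .map(|row| {
            row.iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(i, _)| i)
                .unwrap()
        })
        .collect()
}

fn main() {
    // Two "rows" of three class scores each.
    let logits = [0.1, 2.0, -1.0, 0.3, 0.2, 5.0];
    assert_eq!(argmax_rows(&logits, 3), vec![1, 2]);
    println!("{:?}", argmax_rows(&logits, 3));
}
```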

Training Loop

rust
use burn::nn::loss::CrossEntropyLossConfig;
use burn::optim::AdamConfig;
use burn::tensor::backend::AutodiffBackend;
use burn::train::{ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep};

#[derive(Clone)]
struct MNISTBatch<B: Backend> {
    images: Tensor<B, 2>,
    targets: Tensor<B, 1, Int>,
}

// Training requires an autodiff backend, e.g. Autodiff<NdArray<f32>>.
impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for MLP<B> {
    fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let output = self.forward(batch.images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), batch.targets.clone());

        let item = ClassificationOutput::new(loss.clone(), output, batch.targets);
        TrainOutput::new(self, loss.backward(), item)
    }
}

// A matching ValidStep impl (same forward pass, no backward) is also required; omitted here.

fn train() {
    // `model`, `train_dataloader`, `valid_dataloader` come from setup (not shown).
    let learner = LearnerBuilder::new("./artifacts")
        .num_epochs(10)
        .build(model, AdamConfig::new().init(), 1e-3);  // a plain f64 works as the LR scheduler

    let trained_model = learner.fit(train_dataloader, valid_dataloader);
}

3. Edge AI: WASM & IoT

Candle on WASM

toml
# Cargo.toml
[lib]
crate-type = ["cdylib"]

[dependencies]
candle-core = { version = "0.4", default-features = false }
wasm-bindgen = "0.2"

rust
use wasm_bindgen::prelude::*;
use candle_core::{Device, Tensor};

#[wasm_bindgen]
pub fn predict(input: &[f32]) -> Vec<f32> {
    let device = Device::Cpu;
    let tensor = Tensor::from_slice(input, (1, input.len()), &device).unwrap();

    // Run inference... (`model` is assumed to be loaded elsewhere, e.g. from
    // weights embedded in the binary via include_bytes!)
    let output = model.forward(&tensor).unwrap();
    
    output.to_vec1::<f32>().unwrap()
}

Building for WASM

bash
# Install target
rustup target add wasm32-unknown-unknown

# Build
cargo build --target wasm32-unknown-unknown --release

# Optimize size
wasm-opt -Os -o output.wasm target/wasm32-unknown-unknown/release/mymodel.wasm

IoT / Embedded (no_std)

rust
#![no_std]
#![no_main]

// Note: candle-core itself is not no_std; on microcontrollers the INT8
// kernels are typically hand-written or generated. Sketch only:

// Quantized inference for embedded devices.
// INT8: ~4x less memory, higher throughput on ARM chips.
fn inference_quantized(input: &[i8]) -> i8 {
    // Dot product against INT8 weights, then requantize the accumulator...
    todo!()
}

4. Quantization for Production

Quantization shrinks the model and speeds up inference:

Precision | Memory | Speed    | Accuracy Loss
FP32      | 100%   | Baseline | 0%
FP16      | 50%    | ~2x      | <0.1%
INT8      | 25%    | ~4x      | 0.5-1%
INT4      | 12.5%  | ~8x      | 1-3%

rust
use candle_core::quantized::{QMatMul, QTensor};

// Quantize FP32 weights into 4-bit GGML blocks (Q4_0);
// `weights_f32` is a regular FP32 Tensor loaded beforehand
let weights = QTensor::quantize(&weights_f32, candle_core::quantized::GgmlDType::Q4_0)?;

// Quantized matrix multiplication
let output = QMatMul::from_qtensor(weights)?.forward(&input)?;
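The accuracy loss in the table above comes from rounding during quantization. A minimal symmetric INT8 round trip in plain Rust shows the mechanism (this illustrates the idea behind the GGML block formats, not their exact layout):

```rust
// Symmetric INT8 quantization: map [-max_abs, max_abs] onto [-127, 127].
fn quantize_int8(values: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = values.iter().fold(0.0_f32, |m, v| m.max(v.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q = values.iter().map(|v| (v / scale).round() as i8).collect();
    (q, scale)
}

fn dequantize_int8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}

fn main() {
    let weights = [0.02_f32, -0.7, 0.31, 0.998];
    let (q, scale) = quantize_int8(&weights);
    let restored = dequantize_int8(&q, scale);

    // 4x smaller (i8 vs f32), at the cost of a bounded per-value rounding error.
    for (w, r) in weights.iter().zip(&restored) {
        assert!((w - r).abs() <= scale / 2.0 + 1e-6);
    }
    println!("scale = {scale}, restored = {restored:?}");
}
```

Real formats like Q4_0 apply this per block of 32 values, each block with its own scale, which keeps the rounding error small even when weight magnitudes vary.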

Summary Table

Framework | Use Case                             | Backend
Candle    | LLM inference, HuggingFace models    | CPU, CUDA, Metal
Burn      | Custom models, training + inference  | WGPU, NdArray, LibTorch
tract     | ONNX inference                       | CPU
tch-rs    | PyTorch bindings                     | LibTorch

When to Use What?

  • HuggingFace models (BERT, Llama): → Candle
  • Custom architectures + training: → Burn
  • ONNX models exported from Python: → tract
  • Need full PyTorch compatibility: → tch-rs