HuggingFace Candle MNIST Linear Model Training

Rust
HuggingFace
Candle
ML
Author

曳影

Published

October 16, 2023

使用HuggingFace的Rust机器学习框架训练MNIST

环境

  • Rust: 1.75.0-nightly
  • candle-core: 0.3.0
  • candle-nn: 0.3.0
  • candle-datasets: 0.3.0

注意candle-nn当前版本中依赖了Rust nightly

Cargo.toml内容如下

  1. rand: 随机数
  2. anyhow: 处理异常
  3. clap: 解析命令行参数
[package]
name = "linear_mnist"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }
rand = "0.8.5"
anyhow = "1"
clap = { version = "4.4.4", features = ["derive"] }
candle-datasets = { git = "https://github.com/huggingface/candle.git", version = "0.3.0" }

创建项目并安装Candle相关模块

  1. 使用cargo new创建linear_mnist项目
  2. 进入项目目录
  3. 安装candle三个模块
    • candle-core
    • candle-nn
    • candle-datasets
  4. 安装其他依赖库
    • rand
    • anyhow
    • clap

具体操作如下:

cargo new linear_mnist
cd linear_mnist

cargo add --git https://github.com/huggingface/candle.git candle-core
cargo add --git https://github.com/huggingface/candle.git candle-nn
cargo add --git https://github.com/huggingface/candle.git candle-datasets

代码

导入相关依赖

  1. 导入clap::Parser解析命令行参数
  2. 导入candle_core的相关依赖
    • Device: 数据计算时放置的设备
    • Result: 处理异常
    • Tensor: 张量数据类型
    • D: 是一个enum,包含Minus1Minus2
    • DType: 数据类型enum结构,包含支持的数据类型
  3. 导入candle-nn的相关依赖
    • loss: 损失函数相关操作
    • ops: 函数操作,如log_softmax
    • Linear: 线性模型
    • Module: 由于Linear的依赖
    • Optimizer: 优化器
    • VarBuilder: 构建变量
    • VarMap: 用于存储模型变量
use clap::{ Parser };
use candle_core::{ Device, Result, Tensor, D, DType };
use candle_nn::{ loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap };

定义相关配置

  1. 定义图像维度数量和标签数量的常量
  2. 定义命令行参数解析,并添加指令宏#[derive(Parser)],可以使用clap::Parser解析命令行参数
    • learning_rate: 学习率
    • epochs: 模型训练迭代次数
    • save_model_path: 训练好的模型保存路径
    • load_model_path: 加载预训练模型路径
    • local_mnist: 本地MNIST数据集目录
  3. 定义训练参数结构体TrainingArgs
  4. 定义线性模型结构体LinearModel

具体代码如下:

const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    learning_rate: Option<f64>,

    #[arg(long, default_value_t = 10)]
    epochs: usize,

    #[arg(long)]
    save_model_path: Option<String>,

    #[arg(long)]
    load_model_path: Option<String>,

    #[arg(long)]
    local_mnist: Option<String>,
}

struct TrainingArgs {
    learning_rate: f64,
    load_path: Option<String>,
    save_path: Option<String>,
    epochs: usize,
}

struct LinearModel {
    linear: Linear,
}

定义模型

  1. 定义Model trait
  2. LinearModel实现Model trait
    • new: 初始化模型
    • forward: 模型结构,前向传播
  3. linear_z是具体创建Linear模型
    • 创建模型张量变量并调用candle-nn::Linear创建线性模型返回

具体代码如下:

trait Model: Sized {
    fn new(vs: VarBuilder) -> Result<Self>;
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

impl Model for LinearModel {
    fn new(vs: VarBuilder) -> Result<Self> {
        let linear: Linear = linear_z(IMAGE_DIM, LABELS, vs)?;
        Ok(Self { linear })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.linear.forward(xs)
    }
}

fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
    let ws: Tensor = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
    let bs: Tensor = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
    Ok(Linear::new(ws, Some(bs)))
}

定义模型训练函数

  1. 输入参数
    • m: 数据集
    • args: 训练参数TrainingArgs
  2. 获取或设置模型运算的设备Device::Cpu
  3. 从数据集m中获取训练数据和标签,测试数据和标签
  4. 创建varmap用来存储模型参数
  5. 创建vs变量构造,存储模型参数,并将其传入到Model::new
  6. 如果命令行传入load_model_path,则会加载预训练模型
  7. 创建优化器SGD
  8. 根据epochs迭代训练模型
    • 前向传播得到logits
    • 计算概率log_softmax
    • 计算损失函数值
    • 反向传播sgd.backward_step()
    • 输入测试数据得到测试数据准确率test_accuracy
    • 每个epoch花费的时间epoch_duration
  9. 如果命令传入save_model_path,则会保存模型参数
    • 确保存放模型的目录已经建立

具体代码如下:

fn train<M: Model>(
    m: candle_datasets::vision::Dataset,
    args: &TrainingArgs) -> anyhow::Result<()> {

    let dev = Device::Cpu;

    let train_labels = m.train_labels;
    let train_images = m.train_images.to_device(&dev)?;
    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
    let test_images = m.test_images.to_device(&dev)?;
    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    let mut varmap = VarMap::new();
    let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = M::new(vs.clone())?;

    // Load Pre-trained Model Parameters
    if let Some(load_path) = &args.load_path {
        println!("Loading model from {}", load_path);
        let _ = varmap.load(load_path);
    }

    // Create Optimizer
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;

    // Iterate training model
    for epoch in 1..=args.epochs {
        let start_time = std::time::Instant::now();
        let logits = model.forward(&train_images)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        let loss = loss::nll(&log_sm, &train_labels)?;

        sgd.backward_step(&loss)?;

        let test_logits = model.forward(&test_images)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        let end_time = std::time::Instant::now();
        let epoch_duration = end_time.duration_since(start_time);
        println!("Epoch: {epoch:4} Train Loss: {:8.5} Test Acc: {:5.2}% Epoch duration: {:.2} second.",
                 loss.to_scalar::<f32>()?, test_accuracy * 100., epoch_duration.as_secs_f64());
    }

    // Save Model Parameters
    if let Some(save_path) = &args.save_path {
        println!("Saving trained weight in {save_path}");
        varmap.save(save_path)?
    }
    Ok(())
}

Main函数

  1. 解析命令行参数Args
  2. 根据local_mnist命令行参数指定的目录加载MNIST数据集
  3. 设置学习率
  4. 创建模型训练参数TrainingArgs类型变量training_args并填充设置好的参数
  5. 调用模型训练函数train::<LinearModel>(m, &training_args),传入数据集模型训练参数
fn main() ->anyhow::Result<()> {
    let args: Args = Args::parse();
    let m: candle_datasets::vision::Dataset = if let Some(directory) = args.local_mnist {
        candle_datasets::vision::mnist::load_dir(directory)?
    } else {
        candle_datasets::vision::mnist::load()?
    };

    println!("Train Images: {:?}", m.train_images.shape());
    println!("Train Labels: {:?}", m.train_labels.shape());
    println!("Test  Images: {:?}", m.test_images.shape());
    println!("Test  Labels: {:?}", m.test_labels.shape());

    let default_learning_rate: f64 = 0.1;

    let training_args = TrainingArgs {
        epochs: args.epochs,
        learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
        load_path: args.load_model_path,
        save_path: args.save_model_path,
    };

    train::<LinearModel>(m, &training_args)
}

训练

  1. 如果saved_model不存在,则需要先创建该目录

  2. 目录结构如下

linear_mnist
├── Cargo.lock
├── Cargo.toml
├── dataset
│   ├── t10k-images-idx3-ubyte
│   ├── t10k-labels-idx1-ubyte
│   ├── train-images-idx3-ubyte
│   └── train-labels-idx1-ubyte
├── saved_model
│   └── minst.safetensors
└── src
    └── main.rs
  1. 训练并保存模型参数 MNIST Training

  2. 加载预训练模型继续训练 MNIST Pre-trained

  3. 完整代码地址

Candel Linear Model Training MNIST Classification on Github

Reference

Candle MNIST Training