rust-candle study notes 10: Using Embedding
Reference: about-pytorch
candle-nn provides an embedding() function for initializing an Embedding layer:
pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
    let embeddings = vb.get_with_hints(
        (in_size, out_size),
        "weight",
        crate::Init::Randn {
            mean: 0.,
            stdev: 1.,
        },
    )?;
    Ok(Embedding::new(embeddings, out_size))
}
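Before wiring in the tokenizer and dataset, it helps to see what embedding() gives you on its own. The following is a minimal, self-contained sketch on the CPU; the vocab size of 10 and dimension of 4 are toy numbers chosen only for illustration. The layer's weight is an (in_size, out_size) matrix, and a forward pass is just a row lookup for each token id.

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // Toy sizes: 10 "tokens" in the vocab, 4 dims per token -> weight is (10, 4)
    let emb = embedding(10, 4, vb)?;

    // Token ids must be an integer tensor; each id selects one row of the weight
    let ids = Tensor::new(&[1u32, 5, 9], &device)?;
    let out = emb.forward(&ids)?;
    println!("{:?}", out.dims()); // [3, 4]
    println!("{:?}", out.to_vec2::<f32>()?);
    Ok(())
}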
A first hands-on run with candle's Embedding:
For how the Tokenizer and the dataset are built, see: rust-candle study notes 9 - loading the qwen3 tokenizer with the tokenizers crate and processing text with it.
use candle_core::Device;
use candle_nn::{embedding, Embedding, Module, VarBuilder, VarMap};
use tokenizers::Tokenizer;

// TokenDataset, read_txt and the Result alias are the ones built in note 9.
fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let vocab_size = tokenizer.get_vocab_size(true);
    let text = read_txt("assets/the-verdict.txt")?;
    let device = Device::cuda_if_available(0)?;

    // seq_len = 32 (see note 9 for the remaining TokenDataset arguments)
    let dataset = TokenDataset::new(text, tokenizer, 32, 16, device.clone())?;
    let (inputs, targets) = dataset.get_item(0)?;
    println!(" inputs: {:?}\n", inputs);
    println!(" targets: {:?}\n", targets);
    let len = dataset.len();
    println!("{:?}", len);

    // Randomly initialized embedding table: vocab_size rows, 5 dims per token
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    let embedding = embedding(vocab_size, 5, vb)?;

    // Look up the embedding vector for every token id
    let x_embedding = embedding.forward(&inputs)?;
    let y_embedding = embedding.forward(&targets)?;
    println!(" x_embedding: {:?}\n", x_embedding);
    println!("{:?}", x_embedding.to_vec2::<f32>()?);
    println!(" y_embedding: {:?}\n", y_embedding);
    println!("{:?}", y_embedding.to_vec2::<f32>()?);
    Ok(())
}
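Assuming get_item returns the 1-D tensors of 32 token ids built in note 9, x_embedding and y_embedding should come out as (32, 5) F32 tensors, which is why they can be dumped with to_vec2::<f32>().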
Implementing sinusoidal (sin/cos) position encoding:
struct PositionEmbedding {
    pos_embedding: Tensor,
    device: Device,
}

impl PositionEmbedding {
    fn new(seq_len: usize, embedding_dim: usize, device: Device) -> Result<Self> {
        if embedding_dim % 2 != 0 {
            return Err(Box::new(candle_core::Error::msg("embedding_dim must be even")));
        }
        let mut pos_embedding_vec: Vec<f32> = Vec::with_capacity(seq_len * embedding_dim);
        let w_const: f32 = 10000.0;
        for t in 0..seq_len {
            // Each pair of dimensions (2i, 2i+1) shares one frequency 1 / 10000^(2i/d)
            let i_max = embedding_dim / 2;
            for i in 0..i_max {
                let denominator = w_const.powf(2.0 * i as f32 / embedding_dim as f32);
                let pos_sin_i = (t as f32 / denominator).sin();
                let pos_cos_i = (t as f32 / denominator).cos();
                // Interleave: even dims get sin, odd dims get cos
                pos_embedding_vec.push(pos_sin_i);
                pos_embedding_vec.push(pos_cos_i);
            }
        }
        let pos_embedding = Tensor::from_vec(pos_embedding_vec, (seq_len, embedding_dim), &device)?;
        Ok(Self { pos_embedding, device })
    }
}
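For reference, the loop above computes the standard sinusoidal encoding from "Attention Is All You Need". With d = embedding_dim, position t and dimension-pair index i:

PE(t, 2i)   = sin( t / 10000^(2i/d) )
PE(t, 2i+1) = cos( t / 10000^(2i/d) )

Each adjacent (even, odd) pair of dimensions shares one frequency, and the inner loop pushes the sin value first and the cos value second, which is exactly this interleaved layout.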
Test:
Note: in candle, adding two tensors of different ranks with the plain + operator returns a shape-mismatch error; you have to call the broadcasting add explicitly, and either the higher- or the lower-rank tensor can be the one calling it.
fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let vocab_size = tokenizer.get_vocab_size(true);
    let text = read_txt("assets/the-verdict.txt")?;
    let device = Device::cuda_if_available(0)?;

    let seq_len = 32;
    let dataset = TokenDataset::new(text, tokenizer, seq_len, 16, device.clone())?;
    let batch_size: usize = 6;
    let mut loader = DataLoader::new(dataset, batch_size, true);
    loader.reset();
    let (x, y) = loader.next().unwrap()?;

    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    let embedding_dim: usize = 256;
    let embedding = embedding(vocab_size, embedding_dim, vb)?;

    // Token embeddings: (batch_size, seq_len) ids -> (batch_size, seq_len, embedding_dim)
    let x_embedding = embedding.forward(&x)?;
    let y_embedding = embedding.forward(&y)?;
    println!(" x_embedding: {:?}\n", x_embedding);
    println!(" y_embedding: {:?}\n", y_embedding);

    // Position encodings: (seq_len, embedding_dim)
    let pos_embedding = PositionEmbedding::new(seq_len, embedding_dim, device.clone())?;
    let pos_emb = pos_embedding.pos_embedding;

    // In candle, adding tensors of different ranks with plain + errors out;
    // the broadcasting add has to be called explicitly.
    // Both of the following work:
    let x_input = x_embedding.broadcast_add(&pos_emb)?;
    // let x_input = pos_emb.broadcast_add(&x_embedding)?;
    println!(" x_input: {:?}\n", x_input);
    Ok(())
}
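To see the broadcasting rule in isolation, here is a minimal standalone sketch on toy CPU tensors (the shapes and values are made up for illustration): plain + between a (2, 3) tensor and a (3,) tensor returns an error, while broadcast_add succeeds no matter which operand calls it.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A (2, 3) "embedding" tensor and a (3,) "position" row
    let a = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
    let b = Tensor::new(&[10f32, 20., 30.], &device)?;

    // Plain + requires identical shapes, so this is an error:
    assert!((&a + &b).is_err());

    // broadcast_add expands the lower-rank operand; either side may call it.
    let c = a.broadcast_add(&b)?;
    let d = b.broadcast_add(&a)?;
    println!("{:?}", c.to_vec2::<f32>()?); // [[11, 22, 33], [14, 25, 36]]
    println!("{:?}", d.to_vec2::<f32>()?); // same values
    Ok(())
}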