当前位置: 首页 > news >正文

rust-candle学习笔记10-使用Embedding

参考:about-pytorch

candle-nn提供embedding()初始化Embedding方法:

pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {let embeddings = vb.get_with_hints((in_size, out_size),"weight",crate::Init::Randn {mean: 0.,stdev: 1.,},)?;Ok(Embedding::new(embeddings, out_size))
}

 candle Embedding初体验:

其中Tokenizer和dataset的构造详情参考:rust-candle学习笔记9-使用tokenizers加载qwen3分词,使用分词器处理文本

use candle_nn::{embedding, Embedding, Module, VarBuilder, VarMap};fn main() -> Result<()> {let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;let vocab_size = tokenizer.get_vocab_size(true);let text = read_txt("assets/the-verdict.txt")?;let device = Device::cuda_if_available(0)?;let dataset = TokenDataset::new(text, tokenizer, 32, 16, device.clone())?;let (inputs, targets) = dataset.get_item(0)?;println!(" inputs: {:?}\n", inputs);println!(" targets: {:?}\n", targets);let len = dataset.len();println!("{:?}", len);let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let embedding = embedding(vocab_size, 5, vb)?;let x_embedding = embedding.forward(&inputs)?;let y_embedding = embedding.forward(&targets)?;println!(" x_embedding: {:?}\n", x_embedding);println!("{:?}", x_embedding.to_vec2::<f32>()?);println!(" y_embedding: {:?}\n", y_embedding);println!("{:?}", y_embedding.to_vec2::<f32>()?);Ok(())
}

实现正余弦位置编码:

struct PositionEmbedding {pos_embedding: Tensor,device: Device
}
impl PositionEmbedding {fn new(seq_len: usize, embedding_dim: usize, device: Device) -> Result<Self> {if embedding_dim % 2 != 0 {return Err(Box::new(candle_core::Error::msg("embedding_dim must be even")));}let mut pos_embedding_vec: Vec<f32> = Vec::with_capacity(seq_len * embedding_dim);let w_const: f32 = 10000.0;for t in 0..seq_len {let i_max = embedding_dim / 2;for i in 0..i_max {let denominator = w_const.powf(2.0 * i as f32 / embedding_dim as f32);let pos_sin_i = (t as f32 / denominator).sin();let pos_cos_i = (t as f32 / denominator).cos();pos_embedding_vec.push(pos_sin_i);pos_embedding_vec.push(pos_cos_i);}}let pos_embedding = Tensor::from_vec(pos_embedding_vec, (seq_len, embedding_dim), &device)?;Ok(Self { pos_embedding, device })}
}

测试:

注意:candle 不同维度tensor相加直接用+会报错,要显示的调用广播加,高维tensor和低维tensor谁加谁都可以

fn main() -> Result<()> {let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;let vocab_size = tokenizer.get_vocab_size(true);let text = read_txt("assets/the-verdict.txt")?;let device = Device::cuda_if_available(0)?;let seq_len = 32;let dataset = TokenDataset::new(text, tokenizer, seq_len, 16, device.clone())?;let batch_size: usize = 6;let mut loader = DataLoader::new(dataset, batch_size, true);loader.reset();let (x, y) = loader.next().unwrap()?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let embedding_dim: usize = 256;let embedding = embedding(vocab_size, embedding_dim, vb)?;let x_embedding = embedding.forward(&x)?;let y_embedding = embedding.forward(&y)?;println!(" x_embedding: {:?}\n", x_embedding);println!(" y_embedding: {:?}\n", y_embedding);let pos_embedding = PositionEmbedding::new(seq_len, embedding_dim, device.clone())?;let pos_emb = pos_embedding.pos_embedding;// candle 不同维度tensor相加直接用+会报错,// 广播加要显示的调用// 下面两种方式都可以let x_input = x_embedding.broadcast_add(&pos_emb)?;// let x_input = pos_emb.broadcast_add(&x_embedding)?;println!(" x_input: {:?}\n", x_input);Ok(())
}

http://www.xdnf.cn/news/353917.html

相关文章:

  • Unity基础学习(九)输入系统全解析:鼠标、键盘与轴控制
  • SSHv2公钥认证示例-Paramiko复用 Transport 连接
  • 港大今年开源了哪些SLAM算法?
  • Github 热点项目 Cursor开源代替,AI代理+可视化编程!支持本地部署的隐私友好型开发神器。
  • LVDS系列11:Xilinx Ultrascale系可编程输入延迟(一)
  • 聊聊四种实时通信技术:短轮询、长轮询、WebSocket 和 SSE
  • 推挽输出、开漏输出、上拉电阻、下拉电阻、低边驱动、高边驱动【简版总结】
  • 【Git】查看tag
  • 基于阿里云DataWorks的物流履约时效离线分析
  • STM32定时器5触发定时器4启动
  • 【软件测试】软件缺陷(Bug)的详细描述
  • 使用 NV‑Ingest、Unstructured 和 Elasticsearch 处理非结构化数据
  • 利用GPT实现油猴脚本—网页滚动(优化版)
  • 豆包:基于多模态交互的智能心理咨询机器人系统设计与效果评估——情感计算框架下的对话机制创新
  • Spark,在shell中运行RDD程序
  • 【SQL系列】多表关联更新
  • 手持气象仪:能够实时测量多种气象参数,保数据采集的准确性与实时性
  • 掌握Multi-Agent实践(三):ReAct Agent集成Bing和Google搜索功能,采用推理与执行交替策略,增强处理复杂任务能力
  • Spring Boot 框架概述
  • 【计算机视觉】Car-Plate-Detection-OpenCV-TesseractOCR:车牌检测与识别
  • 【css】css统一设置变量
  • 更新 / 安装 Nvidia Driver 驱动 - Ubuntu - 2
  • 数据类型详解(布尔值、整型、浮点型、字符串等)-《Go语言实战指南》
  • istio in action之Gateway流量入口与安全
  • 分析NVIDIA的股价和业绩暴涨的原因
  • Zabbix监控 RabbitMQ 指定消息队列名称(pull_alarms )的消费者
  • 富乐德传感技术盘古信息 | 锚定“未来工厂”新坐标,开启传感器制造行业数字化转型新征程
  • IC解析之TPS92682-Q1(汽车LED灯控制IC)
  • 【C/C++】C语⾔内存函数
  • [Errno 122] Disk quota exceeded