当前位置: 首页 > news >正文

rust-candle学习笔记11-实现一个简单的自注意力

参考:about-pytorch

定义ScaledDotProductAttention结构体:

use candle_core::{Result, Device, Tensor};
use candle_nn::{Linear, Module, linear_no_bias, VarMap, VarBuilder, ops};struct ScaledDotProductAttention {wq: Linear,wk: Linear,wv: Linear,d_model: Tensor,device: Device,
}

为ScaledDotProductAttention结构体实现new方法:

impl ScaledDotProductAttention {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {Ok(Self { wq: linear_no_bias(embedding_dim, out_dim, vb.pp("wq"))?, wk: linear_no_bias(embedding_dim, out_dim, vb.pp("wk"))?, wv: linear_no_bias(embedding_dim, out_dim, vb.pp("wv"))?,d_model: Tensor::new(embedding_dim as f32, &device)?,device,})}
}

为结构体实现Module的forward trait:

impl Module for ScaledDotProductAttention {fn forward(&self, xs: &Tensor) -> Result<Tensor> {let q = self.wq.forward(xs)?;let k = self.wk.forward(xs)?;let v = self.wv.forward(xs)?;let attn_score = q.matmul(&k.t()?)?;let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;let dim = attn_score.rank() - 1;let attn_weights = ops::softmax(&attn_score, dim)?;let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}
}

融合qkv实现:

定义ScaledDotProductAttentionFusedQKV结构体:

struct ScaledDotProductAttentionFusedQKV {w_qkv: Linear,d_model: Tensor,device: Device,
}

为结构体实现new方法:

impl ScaledDotProductAttentionFusedQKV {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {Ok(Self { w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?,d_model: Tensor::new(embedding_dim as f32, &device)?,device,})}
}

为结构体实现forward trait:

impl Module for ScaledDotProductAttentionFusedQKV {fn forward(&self, xs: &Tensor) -> Result<Tensor> {let qkv = self.w_qkv.forward(xs)?;let (batch_size, seq_len, _) = qkv.dims3()?;let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;let q = qkv.get_on_dim(2, 0)?;let q = q.reshape((batch_size, seq_len, ()))?;let k = qkv.get_on_dim(2, 1)?;let k = k.reshape((batch_size, seq_len, ()))?;let v = qkv.get_on_dim(2, 2)?;let v = v.reshape((batch_size, seq_len, ()))?;let attn_score = q.matmul(&k.t()?)?;let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;let dim = attn_score.rank() - 1;let attn_weights = ops::softmax(&attn_score, dim)?;let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}
}

测试:

fn main() -> Result<()> {let device = Device::cuda_if_available(0)?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55, 0.43, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55], (2, 6, 3), &device)?;// let model = ScaledDotProductAttention::new(vb.clone(), 3, 2, device.clone())?;let model = ScaledDotProductAttentionFusedQKV::new(vb.clone(), 3, 2, device.clone())?;let output = model.forward(&input)?;println!("output: {:?}\n", output);println!("output: {:?}\n", output.to_vec3::<f32>()?);Ok(())
}

http://www.xdnf.cn/news/352063.html

相关文章:

  • 前端工程化和性能优化问题详解
  • Vue3 中 ref 与 reactive 的区别及底层原理详解
  • fakebook
  • 【Linux】深入拆解Ext文件系统:从磁盘物理结构到Linux文件管理
  • 在企业级项目中高效使用 Maven-mvnd
  • 2025-05-10-FFmepg库裁切有水印的视频
  • docker 日志暴露方案 (带权限 还 免费 版本)
  • 企业如何将钉钉付款单高效集成到金蝶云星空?
  • 高频微服务面试题总结
  • 【MySQL】联合查询
  • 自适应混合索引创建与管理:一种智能数据库优化机制的研究
  • 高并发内存池(二):项目的整体框架以及Thread_Cache的结构设计
  • 怎么用idea打jar包
  • 从“山谷论坛”看AI七剑下天山
  • 集成管理工具Gitlab
  • 高清屏幕录像工具 Mirillis Action v4.45.0
  • kitty 终端ssh 命令远程无法正常输入命令
  • 第J7周:ResNeXt解析
  • 【Linux】环境变量(图文)
  • Servlet、HttpServlet 和 DispatcherServlet 区别与关系
  • SPN技术介绍
  • Redis 常见数据类型
  • 新闻发稿筛选媒体核心标准:影响力、适配性与合规性
  • 【LUT技术专题】ECLUT代码解读
  • 如何从极狐GitLab 容器镜像库中删除容器镜像?
  • 解决osx-arm64平台上conda默认源没有提供 python=3.7 的官方编译版本的问题
  • android-ndk开发(11): 安装 repo 命令
  • MySQL + Elasticsearch:为什么要使用ES,使用场景与架构设计详解
  • NAT穿越
  • 力扣-24.两两交换链表中的结点