当前位置: 首页 > java >正文

rust-candle学习笔记12-实现因果注意力

参考:about-pytorch

定义结构体:

/// Single-head causal self-attention layer with a fused Q/K/V projection.
struct CausalAttention {
    /// Bias-free linear layer producing queries, keys and values in one pass
    /// (output width is `3 * out_dim`).
    w_qkv: Linear,
    /// Dropout applied to the attention weights in `forward`.
    dropout: Dropout,
    /// Scalar tensor; `forward` divides the attention scores by its square root.
    d_model: Tensor,
    /// Lower-triangular `seq_len x seq_len` U32 matrix built with `Tensor::tril2`.
    mask: Tensor,
    /// Device passed to the constructor and used to create `d_model` and `mask`.
    device: Device,
}

定义new方法:

impl CausalAttention {
    /// Builds a causal attention layer.
    ///
    /// # Arguments
    /// * `vb` - variable builder; the projection weights are registered under `"w_qkv"`.
    /// * `embedding_dim` - width of each input token vector.
    /// * `out_dim` - width of each of the query, key and value vectors.
    /// * `seq_len` - side length of the causal mask created here.
    /// * `dropout` - drop probability handed to `Dropout::new`.
    /// * `device` - device on which the scale and mask tensors are created.
    ///
    /// # Errors
    /// Propagates any error from creating the linear layer or the tensors.
    fn new(
        vb: VarBuilder,
        embedding_dim: usize,
        out_dim: usize,
        seq_len: usize,
        dropout: f32,
        device: Device,
    ) -> Result<Self> {
        // One projection of width 3 * out_dim yields Q, K and V together.
        let w_qkv = linear_no_bias(embedding_dim, 3 * out_dim, vb.pp("w_qkv"))?;
        // Scaled dot-product attention divides the scores by sqrt(d_k), where d_k
        // is the query/key width. Queries and keys are `out_dim` wide here, so the
        // scale must come from `out_dim`, not from the input embedding width.
        let d_model = Tensor::new(out_dim as f32, &device)?;
        let mask = Tensor::tril2(seq_len, DType::U32, &device)?;
        Ok(Self {
            w_qkv,
            dropout: Dropout::new(dropout),
            d_model,
            mask,
            device,
        })
    }
}

定义forward方法(该方法需放在上面的 impl CausalAttention 块内):

    fn forward(&self, x: &Tensor, train: bool) -> Result<Tensor> { let qkv = self.w_qkv.forward(x)?;let (batch_size, seq_len, _) = qkv.dims3()?;let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;let q = qkv.get_on_dim(2, 0)?;let q = q.reshape((batch_size, seq_len, ()))?;let k = qkv.get_on_dim(2, 1)?;let k = k.reshape((batch_size, seq_len, ()))?;let v = qkv.get_on_dim(2, 2)?;let v = v.reshape((batch_size, seq_len, ()))?;let mut attn_score = q.matmul(&k.t()?)?;// println!("attn_score: {:?}\n", attn_score.to_vec3::<f32>()?);let dim = attn_score.rank() - 1;let mask_dim = attn_score.dims()[dim];let mask = self.mask.broadcast_as(attn_score.shape())?;// println!("mask: {:?}\n", mask);// println!("mask: {:?}\n", mask.to_vec3::<u32>()?);attn_score = masked_fill(&attn_score, &mask, f32::NEG_INFINITY)?;// println!("attn_score: {:?}\n", attn_score);// println!("attn_score: {:?}\n", attn_score.to_vec3::<f32>()?);let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?; let attn_weights = ops::softmax(&attn_score, dim)?;// println!("attn_weights: {:?}\n", attn_weights);// println!("attn_weights: {:?}\n", attn_weights.to_vec3::<f32>()?); let attn_weights = self.dropout.forward(&attn_weights, train)?;// println!("dropout attn_weights: {:?}\n", attn_weights);// println!("dropout attn_weights: {:?}\n", attn_weights.to_vec3::<f32>()?); let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}

测试:

/// Smoke test: runs the causal attention layer on a batch of two identical
/// six-token sequences (three features per token) and prints the result.
fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);

    // One sequence of 6 tokens x 3 features, row-major.
    let sequence = [
        0.43f32, 0.15, 0.89, //
        0.55, 0.87, 0.66, //
        0.57, 0.85, 0.64, //
        0.22, 0.58, 0.33, //
        0.77, 0.25, 0.10, //
        0.05, 0.80, 0.55,
    ];
    // The batch holds the same sequence twice.
    let batch = [sequence, sequence].concat();
    let input = Tensor::from_vec(batch, (2, 6, 3), &device)?;

    let model = CausalAttention::new(vb.clone(), 3, 2, 6, 0.5, device.clone())?;
    let output = model.forward(&input, true)?;

    println!("output: {:?}\n", output);
    println!("output: {:?}\n", output.to_vec3::<f32>()?);
    Ok(())
}

http://www.xdnf.cn/news/5298.html

相关文章:

  • 有效的括号(简单)
  • ESP32配置GPIO,实现每0.5秒翻转LED电平
  • python笔记和练习----少儿编程课程【阶段二(二)】
  • C++--类的构造函数与初始化列表差异
  • 抖音视频上传功能测试全维度拆解——从基础功能到隐藏缺陷的深度挖掘
  • 【八股消消乐】项目中如何优化JVM内存分配?
  • [题解]2023CCPC黑龙江省赛 - Ethernet
  • Java多线程同步方法ReentrantLock显式锁实现方式
  • Python数据分析
  • Spring 6.x 详解介绍
  • 【从零实现JsonRpc框架#1】Json库介绍
  • 基于NI-PXI的HIL系统开发
  • MySQL 1366 - Incorrect string value:错误
  • MySQL:视图
  • 串口屏调试 1.0
  • ComfyUI 如何安装ComfyUI_SLK_joy_caption_two
  • window环境下,如何通过USB接口控制打印机
  • 质心均匀体(引力屏蔽技术)
  • 算法训练营第十三天|226.翻转二叉树、101. 对称二叉树、 104.二叉树的最大深度、111.二叉树的最小深度
  • 多模态大模型中的视觉分词器(Tokenizer)前沿研究介绍
  • 【入门】数字走向II
  • JavaScript 数组去重:11 种方法对比与实战指南
  • 什么是 B2B?2B 产品销售怎么找客户?
  • Unity基础学习(十)Camera组件
  • [ctfshow web入门] web67
  • JVM对象创建内存分配
  • [特殊字符]️ 快速检测与修复TLS 1.0/1.1漏洞指南
  • 人形机器人:主控芯片
  • 红黑树算法笔记(二)性能对比实验
  • 解密数据结构之位图和布隆过滤器