当前位置: 首页 > news >正文

rust-candle学习笔记13-实现多头注意力

参考:about-pytorch

定义结构体:

use core::f32;use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, linear_no_bias, linear, ops, Dropout, Linear, Module, VarBuilder, VarMap};struct MultiHeadAttention {w_qkv: Linear,dropout: Dropout, d_model: Tensor,mask: Tensor,out_proj: Linear,device: Device,out_dim: usize,num_heads: usize,head_dim: usize,
}

定义初始化方法:

impl MultiHeadAttention {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, seq_len: usize, num_heads: usize, drop_p: f32, device: Device) -> Result<Self> {if out_dim % num_heads != 0 {return Err(candle_core::Error::msg("out_dim must be divisible by num_heads"));}Ok(Self { w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?, dropout: Dropout::new(drop_p), d_model: Tensor::new(embedding_dim as f32, &device)?, mask: Tensor::tril2(seq_len, DType::U32, &device)?, out_proj: linear(out_dim, out_dim, vb.pp("out_proj"))?, device, out_dim, num_heads, head_dim: out_dim / num_heads, })}
}

定义forward方法:

fn forward(&self, x: &Tensor, train: bool) -> Result<Tensor> {let qkv = self.w_qkv.forward(x)?;let (batch_size, seq_len, _) = qkv.dims3()?;let qkv = qkv.reshape((batch_size, seq_len, 3, self.num_heads, self.head_dim))?;let q = qkv.get_on_dim(2, 0)?;// Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)let q = q.transpose(1, 2)?.contiguous()?;let k = qkv.get_on_dim(2, 0)?;let k = k.transpose(1, 2)?.contiguous()?;let v = qkv.get_on_dim(2, 0)?;let v = v.transpose(1, 2)?.contiguous()?;let attn_scores = q.matmul(&k.transpose(2, 3)?)?;let mask = self.mask.broadcast_as(attn_scores.shape())?;let attn_scores = masked_fill(&attn_scores, &mask, f32::NEG_INFINITY)?;let attn_scores = attn_scores.broadcast_div(&self.d_model.sqrt()?)?;let softmax_dim = attn_scores.rank() - 1;// let attn_weights = ops::softmax_last_dim(&attn_scores)?;  //如果是cpu,可以用这个let attn_weights = ops::softmax(&attn_scores, softmax_dim)?;let attn_weights = self.dropout.forward(&attn_weights, train)?;let attn_output = attn_weights.matmul(&v)?;let attn_output = attn_output.transpose(1, 2)?;let attn_output = attn_output.reshape(&[batch_size, seq_len, self.num_heads*self.head_dim])?;let attn_output = self.out_proj.forward(&attn_output)?;Ok(attn_output)}

测试:

fn main() -> Result<()> {let device = Device::cuda_if_available(0)?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55, 0.43, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55], (2, 6, 3), &device)?;let model = MultiHeadAttention::new(vb.clone(), 3, 4, 6, 2, 0.1, device.clone())?;let output = model.forward(&input, true)?;println!("output: {:?}\n", output);println!("output: {:?}\n", output.to_vec3::<f32>()?);Ok(())
}

http://www.xdnf.cn/news/371899.html

相关文章:

  • Skyvern:用 AI+视觉驱动浏览器自动化
  • FreeTex v0.2.0:功能升级/支持Mac
  • Ubuntu 22.04(WSL2)使用 Docker 安装 Zipkin 和 Skywalking
  • 【含文档+PPT+源码】基于微信小程序的社区便民防诈宣传系统设计与实现
  • 基本句子结构
  • 前端取经路——现代API探索:沙僧的通灵法术
  • 每天五分钟机器学习:KTT条件
  • 在 Excel 中有效筛选重复元素
  • Stable Diffusion XL 文生图
  • 【金仓数据库征文】金融行业中的国产化数据库替代应用实践
  • C语言的中断 vs Java/Kotlin的异常:底层机制与高级抽象的对比
  • 365打卡第R8周: RNN实现阿尔茨海默病诊断
  • RAG 2.0 深入解读
  • 内存、磁盘、CPU区别,Hadoop/Spark与哪个联系密切
  • 海盗王64位服务端+32位客户端3.0版本
  • k8s删除pv和pvc后,vg存储没释放分析
  • Leetcode (力扣)做题记录 hot100(543,102,35,101)
  • AI:PS软件:ps软件中如何使用人工智能(AI)?
  • SierraNet协议分析使用指导[RDMA]| 如何设置 NVMe QP 端口以进行正确解码
  • 画立方体软件开发笔记 js three 投影 参数建模 旋转相机 @tarikjabiri/dxf导出dxf
  • 代码随想录第41天:图论2(岛屿系列)
  • Git简介和发展
  • 代码复用与分层
  • 双目视觉系统中,极线校正(Epipolar Rectification)与单应性矩阵/多平面单应性模型
  • 通过推测搜索加速大型语言模型推理 (SpecSearch) 论文总结
  • 零基础入门MySQL:10分钟搞定数据库基本操作
  • tryhackme——Enumerating Active Directory
  • 【Linux】冯诺依曼体系结构和操作系统的理解
  • Webug4.0通关笔记25- 第30关SSRF
  • JS较底层的用法,几类简单介绍