Self-Attention, Multi-Head Self-Attention, and Padding Masks: A Python Implementation

Theory

【Transformer Series (2)】Attention, Self-Attention, Multi-Head Attention, Channel Attention, and Spatial Attention Explained in Detail

Self-Attention
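The module below implements scaled dot-product self-attention: the input is projected to queries $Q$, keys $K$, and values $V$, and the output is

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V,$$

where $d$ is the feature dimension (`input_dim` in the code). A padding mask sets the scores of padded key positions to $-\infty$ before the softmax, so they receive zero attention probability.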

import torch
import torch.nn as nn

# Self-attention
class SelfAttention(nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, input_dim = x.shape
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Scaled dot-product scores: (batch_size, seq_len, seq_len)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(input_dim, dtype=torch.float))
        if mask is not None:
            # mask: (batch_size, seq_len) -> (batch_size, 1, seq_len) for broadcasting
            mask = mask.unsqueeze(1)
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
        attn_scores = torch.softmax(attn_weights, dim=-1)
        attended_values = torch.matmul(attn_scores, v)
        return attended_values

# Automatic padding helper
def pad_sequences(sequences, max_len=None):
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    max_len = max_len or lengths.max().item()
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
    # 1 for real tokens, 0 for padding
    mask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)
    return padded, mask.long()

if __name__ == '__main__':
    input_dim = 128
    seq_len_1 = 3
    seq_len_2 = 5
    x1 = torch.randn(seq_len_1, input_dim)
    x2 = torch.randn(seq_len_2, input_dim)
    target_seq_len = 10
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)
    selfattention = SelfAttention(input_dim)
    attention = selfattention(padded_x, mask)  # pass the mask so padding is ignored
    print(attention)
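As a quick illustration of the masking step (a minimal standalone sketch, not part of the class above): filling masked positions with `-inf` before the softmax drives their probabilities to exactly zero.

import torch

scores = torch.randn(1, 4, 4)        # toy attention logits for one sequence of length 4
mask = torch.tensor([[1, 1, 0, 0]])  # hypothetical mask: last two positions are padding
scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
probs = torch.softmax(scores, dim=-1)
print(probs[0, 0])  # the last two entries are exactly 0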

Multi-Head Self-Attention
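Multi-head attention runs the same scaled dot-product attention $h$ times in parallel on lower-dimensional slices of the projections. With $d_h = \texttt{input\_dim} / \texttt{num\_heads}$,

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_h}}\right)V_i, \qquad \mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h).$$

Note that `input_dim` must be divisible by `num_heads` (here $128 / 2 = 64$ per head), and that, unlike the original Transformer, the implementation below omits the final output projection $W^O$ after concatenation.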

import torch
import torch.nn as nn

# Multi-head self-attention module
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, input_dim = x.shape
        # Split the input into multiple heads;
        # after transpose(1, 2) the shape is (batch_size, num_heads, seq_len, head_dim)
        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Compute attention weights, scaled by the per-head dimension
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32))
        # Apply the padding mask
        if mask is not None:
            # mask: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) for broadcasting
            mask = mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
        attn_scores = torch.softmax(attn_weights, dim=-1)
        # Attention-weighted sum, then merge heads back to (batch_size, seq_len, input_dim)
        attended_values = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(
            batch_size, seq_len, input_dim)
        return attended_values

# Automatic padding helper
def pad_sequences(sequences, max_len=None):
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    max_len = max_len or lengths.max().item()
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq
    mask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)
    return padded, mask.long()

if __name__ == '__main__':
    heads = 2
    seq_len_1 = 3
    seq_len_2 = 5
    input_dim = 128
    x1 = torch.randn(seq_len_1, input_dim)
    x2 = torch.randn(seq_len_2, input_dim)
    target_seq_len = 10
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)
    multiheadattention = MultiHeadSelfAttention(input_dim, heads)
    attention = multiheadattention(padded_x, mask)
    print(attention)
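For comparison, PyTorch ships an equivalent building block, `torch.nn.MultiheadAttention`. A minimal sketch follows; note that its `key_padding_mask` convention is inverted relative to the mask above (`True` marks positions to ignore), and that the built-in also applies the output projection $W^O$ that the hand-rolled version omits.

import torch
import torch.nn as nn

embed_dim, num_heads = 128, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(2, 10, embed_dim)
mask = torch.ones(2, 10, dtype=torch.long)
mask[0, 3:] = 0  # first sequence has real length 3
mask[1, 5:] = 0  # second sequence has real length 5

# key_padding_mask expects True at padded positions, i.e. the inverse of the 1/0 mask
out, attn_weights = mha(x, x, x, key_padding_mask=(mask == 0))
print(out.shape)  # torch.Size([2, 10, 128])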