当前位置: 首页 > news >正文

深入解析PyTorch中MultiheadAttention的参数key_padding_mask与attn_mask

1. 基本背景

在multiheadattention中存在两个mask,一个参数是key_padding_mask,另外一个是attn_mask,尽管这两个参数是被人们所熟知的填充掩码和注意力掩码,但是深度理解以便清晰区分对于深刻理解该架构非常重要。

2. 参数Key_padding_mask(关键填充掩码)

  • 用途:防止模型关注到输入序列中用 <pad> 填充的位置。
  • 场景:对变长输入进行 padding 后,避免注意力将注意力权重分配到 padding token 上。
  • 应用位置:在计算注意力时,对 所有 query 的 key 位置 进行屏蔽。

✅维度

# key_padding_mask shape: (batch_size, seq_len)

✅ 示例

key_padding_mask = torch.tensor([[False, False, True], [False, True, True]])
# 表示第一个样本第3个位置是pad,第二个样本第2,3个位置是pad

3. 参数Attn_mask(注意力掩码)

  • 用途:对注意力矩阵中任意 query-key 对的连接进行屏蔽,更灵活。
  • 场景:
    • Transformer 解码器中的 自回归遮蔽(causal mask)
    • 限定注意力只能在局部范围内滑动(局部注意力)
    • 自定义 mask,如节省计算或实验结构

✅ 维度

# [tgt_len, src_len](用于所有 batch 和 head)
# 或 [batch_size * num_heads, tgt_len, src_len](用于每个 head 的个性化 mask)

✅ 示例:causal mask

# 上三角为 True,代表“未来的信息被屏蔽”,用于解码器自回归。
tgt_len = 5
attn_mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()

4. 工作流程中的区别⚠️⚠️⚠️

在计算 Q ∗ K T Q*K^T QKT之后:

  1. 先应用attn_mask(对齐注意力矩阵维度,屏蔽某些query-key配对);
  2. 再应用key_padding_mask(对每个样本的padding key屏蔽);
  3. 最后经过softmax处理

5. 类比理解

  • key_padding_mask 像是说:“这些 token 是 padding,不用关注它们。”
  • attn_mask 像是说:“这些 query-key 配对不允许有连接(比如未来的信息)。”
http://www.xdnf.cn/news/533989.html

相关文章:

  • 【AI时代】Java程序员大模型应用开发详细教程(上)
  • ALTER AGGREGATE使用场景
  • Pod 节点数量
  • 【Game】Powerful——Punch and Kick(12)
  • 阿里世界偏好模型:WorldPM-72B论文速读
  • LangChain框架核心技术:从链式工作流到结构化输出的全栈指南
  • Spring的后置处理器是干什么用的?扩展点又是什么?
  • 数据结构学习笔记—初识数据结构
  • 用Caffeine和自定义注解+AOP优雅实现本地防抖接口限流
  • 玉米籽粒发育
  • spring boot 注解 @bean
  • 打卡30天
  • 【IDEA】删除/替换文件中所有包含某个字符串的行
  • ROS2简介
  • 关于ECMAScript的相关知识点!
  • 适合学人工智能的专业有哪些?
  • 【算法】滑动窗口动态查找不含重复字符的最长子串
  • 同一颗太阳:Australia、Austria、Arab、Africa、Augustus、August、Aurora、Athena
  • input组件使用type=“number“的时候,光标自动跳到首位
  • 深度学习基础——神经网络优化算法
  • 免费私有化部署! PawSQL社区版,超越EverSQL的企业级SQL优化工具面向个人开发者开放使用了
  • 游戏盾的功有哪些?
  • AGI大模型(27):LangChain向量存储
  • react事件绑定的方法
  • 桌面麒麟系统下的GMAC调试日志
  • HTTPS、SSL证书是啥?网站“安全小锁”的入门科普
  • 基于 STC89C52 的料仓物位监测系统设计与实现
  • 自动化调参工具:VOFA+可视化参数
  • java集合详细讲解
  • Java集合框架解析:从基础到底层源码