当前位置: 首页 > news >正文

[machine learning] Transformer - Attention (三)

上一篇文章我们介绍实现了带训练参数的self-attention类。本文在上一篇文章基础上进一步介绍因果attention(causal attention)。

Causal attention,也叫masked attention,是self-attention的一种特殊形式。跟标准的self-attention考虑输入序列的全部单词不同,causal attention只允许考虑当前单词之前的单词(即屏蔽当前单词未来的单词)。下面是屏蔽过程的示意图:
在这里插入图片描述
下面是给causal attention添加mask的一个例子:

# Reuse the query and key weight matrices of the
# 这里使用上一篇文章介绍的 SelfAttention_v2 类
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs) 
attn_scores = queries @ keys.Tattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)
# 输出:
#tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
#        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
#        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
#        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
#        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
#        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
#       grad_fn=<SoftmaxBackward0>)context_length = attn_scores.shape[0]
# 屏蔽上三角,即每个单词只能看到它之前的单词
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
# 输出:
#tensor([[1., 0., 0., 0., 0., 0.],
#        [1., 1., 0., 0., 0., 0.],
#        [1., 1., 1., 0., 0., 0.],
#        [1., 1., 1., 1., 0., 0.],
#        [1., 1., 1., 1., 1., 0.],
#        [1., 1., 1., 1., 1., 1.]])# 屏蔽后的attention weights
masked_simple = attn_weights*mask_simple
print(masked_simple)
# 输出
#tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
#        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
#        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
#        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
#        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
#       grad_fn=<MulBackward0>)# 归一化屏蔽后的attention weights,使得每一行和为1
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
# 输出:
# tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#       [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
#        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
#        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
#        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
#        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
#       grad_fn=<DivBackward0>)

上面的这个例子是做完Softmax之后再加mask,这打破了Softmax创建的概率分布,因此需要重新归一化使得每一行的和重新为1。这种做法虽然最后依然能够使得每一行和为1,但是重新归一化的操作增加了复杂度并可能引入意想不到的效果。

下面介绍一种利用Softmax的特性一步到位的方法:

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
# 输出:
#tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
#        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
#        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
#        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
#        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
#        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
#       grad_fn=<MaskedFillBackward0>)attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)
# 输出:
#tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
#        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
#        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
#        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
#        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
#       grad_fn=<SoftmaxBackward0>)

这种方法把需要屏蔽的值设为负无穷,而对于负无穷,softmax会转换成0。因此,巧妙地既把屏蔽值权重降为0,又保证每一行的概率和为1。

下面介绍怎么在causal attention的计算中添加dropoutdropout是用来防止模型过拟合,并且只在训练的时候用,推理的时候不用dropout。在transformer架构中,attention机制中的dropout主要有两种用法:一种是在attention scores上做dropout;另一种是在attention weights上做dropout。目前,主流的做法是在attention weights上做dropout,示意图如下:
在这里插入图片描述
下面在一个6×6的全1矩阵上做dropout,以这个简单例子介绍dropout机制:

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) # dropout rate of 50%
example = torch.ones(6, 6) # create a matrix of onesprint(dropout(example))
# 输出:
#tensor([[2., 2., 0., 2., 2., 0.],
#        [0., 0., 0., 2., 0., 2.],
#        [2., 2., 2., 2., 0., 2.],
#        [0., 2., 2., 0., 0., 2.],
#        [0., 2., 0., 2., 0., 2.],
#        [0., 2., 2., 2., 2., 0.]])

可以看到,有一半的值被drop了(即被赋值0)。还有一个注意点是,dropout之后,没有被drop的值数值也改变了。这是因为,为了弥补有些值被drop了的影响,没有被drop的值会被放大(scale up)。这种放大机制对于维持attention weights的整体平衡至关重要,因为它确保了在训练和推理时attention机制的平均影响保持一致。没有被drop的值的放大公式是:1 / (1 - dropout_rate)

下面我们对之前的attention weights做dropout(在不同的操作系统上,结果会有不同):

torch.manual_seed(123)
print(dropout(attn_weights))# 输出:
#tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
#        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
#        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
#        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
#       grad_fn=<MulBackward0>)

了解了causal attention和dropout masking这两种机制后,下面我们实现一个causal attention类:

class CausalAttention(nn.Module):def __init__(self, d_in, d_out, context_length,dropout, qkv_bias=False):super().__init__()self.d_out = d_outself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.dropout = nn.Dropout(dropout) # Newself.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # Newdef forward(self, x):b, num_tokens, d_in = x.shape # New batch dimension b# For inputs where `num_tokens` exceeds `context_length`, this will result in errors# in the mask creation further below.# In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs  # do not exceed `context_length` before reaching this forward method. keys = self.W_key(x)queries = self.W_query(x)values = self.W_value(x)attn_scores = queries @ keys.transpose(1, 2) # Changed transposeattn_scores.masked_fill_(  # New, _ ops are in-placeself.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_sizeattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights) # Newcontext_vec = attn_weights @ valuesreturn context_vec

在这个类中,新加了self.register_buffer()方法,这个方法不是必须的,但是使用它的好处是:buffers会自动移到模型使用的device上(CPU或GPU),这样我们就不必手动去确保和模型的device一致,避免了device不一致的错误。


下面我们使用CausalAttention类来计算上下文向量:

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) # 2 inputs with 6 tokens each, and each token has embedding dimension 3
# 输出:
# torch.Size([2, 6, 3])torch.manual_seed(123)context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)context_vecs = ca(batch)print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
# 输出:
#tensor([[[-0.4519,  0.2216],
#         [-0.5874,  0.0058],
#         [-0.6300, -0.0632],
#         [-0.5675, -0.0843],
#         [-0.5526, -0.0981],
#         [-0.5299, -0.1081]],#        [[-0.4519,  0.2216],
#         [-0.5874,  0.0058],
#         [-0.6300, -0.0632],
#         [-0.5675, -0.0843],
#         [-0.5526, -0.0981],
#         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
# context_vecs.shape: torch.Size([2, 6, 2])

这里简单使用了torch.stack来模拟batch输入。




参考资料:
《Build a Large Language Model from Scratch》

http://www.xdnf.cn/news/290899.html

相关文章:

  • C++ 检查某个点是否存在于圆扇区内(Check whether a point exists in circle sector or not)
  • 2025流感疫苗指南+卫健委诊疗方案|高危人群防护+并发症处理 慢性肾脏病饮食指南2025卫健委版|低盐低磷食谱+中医调理+PDF 网盘下载 pdf下载
  • Scala day6(Class,field,Single Object)
  • EPSG:3857 和 EPSG:4326 的区别
  • 掌纹图像识别:解锁人类掌纹/生物识别的未来——技术解析与前沿数据集探索
  • 2025系统架构师---论软件的设计模式论文
  • Java按字节长度截取字符串指南
  • JVM——Java对象的内存布局
  • Hive安装与配置教程
  • 详讲viewer查看器
  • Astro Canvas 数据中心→设备一览大屏操作指南
  • 基于 HTML5 的贪吃蛇小游戏实现
  • Oracle数据库从入门到掌握基础应用能力
  • 16. Qt系统相关:事件、定时器
  • 金融的本质是智融、融资的实质是融智、投资的关键是投智,颠覆传统金融学的物质资本中心论,构建了以智力资本为核心的新范式
  • 启发式算法-禁忌搜索算法
  • Python学习之路(七)-绘画and动画
  • 使用 JavaScript 实现数据导出为 Excel 和 CSV 文件
  • Ultra7-265K 和 技嘉Z890M-AORUS-ELITE-WIFI7主板 简单开箱测评
  • 《Python星球日记》第29天:Flask进阶
  • Unity-Shader详解-其四
  • python计算shp中每个区域的面积
  • Linux 怎么使用局域网内电脑的网络访问外部
  • android-ndk开发(6): 查看反汇编
  • 《算法导论(第4版)》阅读笔记:p7-p8
  • 售前赢单评分是越权吗?
  • 第二章-猜数游戏
  • 数据集-目标检测系列- 牙刷 检测数据集 toothbrush >> DataBall
  • 分析strtol(),strtoul()和strtod()三个函数的功能
  • 字符串哈希专题