
How to Control the Distribution of Attention Scores with Entropy Regularization

First, write a CrossAttention module:

import tensorflow as tf
from tensorflow.keras import layers

# input:  Q (B, L, d), KV (B, N, d)
# output: (B, L, dim)
# 0 < alpha <= ln(N); the closer alpha is to 0, the closer the attention
# scores get to a one-hot distribution
class CrossAttention(layers.Layer):
    def __init__(self, num_head, dim, alpha, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.num_head = num_head
        self.dim = dim
        self.layernorm = layers.LayerNormalization()

    def build(self, input_shape):
        self.qdk = self.add_weight(name='query_dense_kernel',
                                   shape=[input_shape[0][-1], self.num_head, self.dim])
        self.kdk = self.add_weight(name='key_dense_kernel',
                                   shape=[input_shape[1][-1], self.num_head, self.dim])
        self.vdk = self.add_weight(name='value_dense_kernel',
                                   shape=[input_shape[1][-1], self.num_head, self.dim])
        self.odk = self.add_weight(name='output_dense_kernel',
                                   shape=[self.num_head, self.dim, self.dim])
        self.odb = self.add_weight(name='output_dense_bias', shape=[self.dim])

    def call(self, inputs, *args, **kwargs):
        Q, KV = inputs
        # Per-head projections: (B, L or N, num_head, dim)
        query = tf.einsum("abc,cde->abde", Q, self.qdk)
        key = tf.einsum("abc,cde->abde", KV, self.kdk)
        value = tf.einsum("abc,cde->abde", KV, self.vdk)
        # Scaled dot-product attention; scores: (B, num_head, L, N), softmax over N
        query = tf.multiply(query, 1.0 / tf.math.sqrt(float(self.dim)))
        attention_scores = tf.math.softmax(tf.einsum("abcd,aecd->acbe", query, key))
        # Entropy regularization: push the entropy of every score row towards alpha
        self.add_loss(tf.reduce_mean(
            (-tf.reduce_sum(attention_scores * tf.math.log(attention_scores + 1e-07),
                            axis=-1) - self.alpha) ** 2))
        # Weighted sum of values: (B, L, num_head, dim)
        attention_output = tf.einsum("abcd,aceb->aecd", value, attention_scores)
        # Merge heads back to (B, L, dim)
        output = tf.einsum("abcd,cde->abe", attention_output, self.odk) + self.odb
        return self.layernorm(output + Q), attention_scores
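A quick shape check (a minimal sketch; the batch size is arbitrary, the other hyperparameters mirror the experiment further below):

layer = CrossAttention(num_head=3, dim=16, alpha=0.2)
Q = tf.random.uniform((2, 4, 16))    # (B, L, d)
KV = tf.random.uniform((2, 6, 16))   # (B, N, d)
out, scores = layer((Q, KV))
print(out.shape)          # (2, 4, 16)  = (B, L, dim)
print(scores.shape)       # (2, 3, 4, 6) = (B, num_head, L, N)
print(len(layer.losses))  # 1: the entropy regularization term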

The loss function consists of two parts: a prediction loss and a regularization loss.
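Concretely, writing $H(\cdot)$ for Shannon entropy (in nats) and $\alpha$ for the target entropy, the objective minimized below is:

$$\mathcal{L} = \mathrm{BCE}(y, \hat{y}) + \operatorname{mean}\big[\big(H(\text{attention\_scores}) - \alpha\big)^2\big]$$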

The regularization loss is registered with the add_loss method. Losses added this way can be accessed through model.losses, which returns a collection with one element per registered regularization loss.
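For illustration, here is a minimal, self-contained sketch of how add_loss and model.losses interact (the layer name and penalty are made up):

class ActivityPenalty(layers.Layer):
    def call(self, inputs):
        # Each add_loss call contributes one element to model.losses.
        self.add_loss(1e-2 * tf.reduce_sum(tf.square(inputs)))
        return inputs

inp = tf.keras.Input(shape=(4,))
m = tf.keras.Model(inp, ActivityPenalty()(inp))
m(tf.ones((1, 4)))  # the forward pass creates the loss tensor
print(m.losses)     # a list with one scalar tensor per add_loss call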

Once the regularization loss has been registered with add_loss, the forward pass that creates it must run inside the scope of tf.GradientTape() so that it contributes to the gradients:

def fit(x, y, epochs, model):
    optimizer = tf.keras.optimizers.Adam()
    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch,))
        with tf.GradientTape() as tape:
            logits = model(x, training=True)[0]
            # Compute the prediction loss for this minibatch.
            loss_value = tf.keras.losses.binary_crossentropy(y, logits)
            print(model.losses)
            # Add the entropy regularization losses collected via add_loss.
            loss_value += sum(model.losses)
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        print("attention scores entropy loss: %s" % (sum(model.losses)))
        print("loss: %s" % loss_value)

Next, set up a simple task and some data to see the effect of the entropy regularization:

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.CA = CrossAttention(3, 16, 0.2)

    def build(self, input_shape):
        self.k = self.add_weight(name="predict_kernel",
                                 shape=[input_shape[0][-2], 16, 2])

    def call(self, inputs, *args, **kwargs):
        x, scores = self.CA(inputs)
        return tf.math.sigmoid(tf.einsum("abc,bcd->ad", x, self.k)), scores


if __name__ == '__main__':
    Q = tf.random.uniform((1, 4, 16))
    KV = tf.random.uniform((1, 6, 16))
    labels = tf.constant([[0., 1.]])
    model = Model()
    fit((Q, KV), labels, epochs=1000, model=model)
    print(model((Q, KV)))

After the model is trained, print the distribution of the attention scores: every row of attention scores is close to a one-hot distribution.

The entropy of a probability distribution over k outcomes has a minimum of 0 and a maximum of ln k. The minimum is attained by a one-hot distribution (lowest entropy), the maximum by the uniform distribution (highest entropy). The entropy regularization loss used here is (Entropy(scores) - alpha)^2, so by adjusting alpha you can control how closely the attention scores approach a one-hot distribution.
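A quick numerical check of these bounds, using the same epsilon-guarded entropy as the layer above:

def entropy(p):
    return -tf.reduce_sum(p * tf.math.log(p + 1e-07), axis=-1)

k = 6
print(entropy(tf.one_hot(2, k)).numpy())       # ~0.0: one-hot, minimum entropy
print(entropy(tf.fill([k], 1.0 / k)).numpy())  # ~1.79 = ln(6): uniform, maximum entropy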

<tf.Tensor: shape=(1, 3, 4, 6), dtype=float32, numpy=
array([[[[3.27099487e-03, 2.51237601e-02, 9.59103048e-01, 6.17671618e-03, 4.41661570e-03, 1.90874794e-03],
         [3.13497148e-03, 2.45431308e-02, 9.60317731e-01, 5.98660251e-03, 4.21733223e-03, 1.80015271e-03],
         [3.40689556e-03, 2.57629622e-02, 9.57871556e-01, 6.44017057e-03, 4.52079810e-03, 1.99752697e-03],
         [1.10661890e-03, 1.27607975e-02, 9.81542170e-01, 2.43383530e-03, 1.57478068e-03, 5.81745524e-04]],

        [[3.10348310e-02, 8.87579299e-05, 4.77730093e-04, 1.46521628e-03, 9.52835977e-01, 1.40975416e-02],
         [2.63910089e-02, 5.77273713e-05, 3.39534425e-04, 1.06126023e-03, 9.60522711e-01, 1.16277682e-02],
         [3.12446002e-02, 9.48462111e-05, 5.04475611e-04, 1.46811537e-03, 9.52523112e-01, 1.41648324e-02],
         [1.50452955e-02, 1.21340790e-05, 9.36727956e-05, 3.67841218e-04, 9.78706181e-01, 5.77491429e-03]],

        [[2.43717595e-03, 2.99123675e-02, 9.57738578e-01, 8.62001721e-03, 6.63190091e-04, 6.28605427e-04],
         [2.37493636e-03, 2.91979928e-02, 9.58939075e-01, 8.24017916e-03, 6.44378248e-04, 6.03430963e-04],
         [2.84763589e-03, 3.26210111e-02, 9.53170657e-01, 9.78391431e-03, 8.07495147e-04, 7.69288861e-04],
         [1.05220091e-03, 1.88548081e-02, 9.75066125e-01, 4.56135161e-03, 2.38224253e-04, 2.27281591e-04]]]],
      dtype=float32)>
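To verify, recompute the per-row entropy of the trained scores (a sketch assuming the model and inputs from the script above); with alpha = 0.2, every entry should land near 0.2, far below the maximum ln(6) ≈ 1.79:

_, scores = model((Q, KV))
row_entropy = -tf.reduce_sum(scores * tf.math.log(scores + 1e-07), axis=-1)
print(row_entropy)  # shape (1, 3, 4); each entry close to alpha = 0.2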
