How to Control the Distribution of Attention Scores with Entropy Regularization
First, write a CrossAttention module:
import tensorflow as tf
from tensorflow.keras import layers


# input: Q(B, L, d), KV(B, N, d)
# output: (B, L, dim)
# 0 < alpha <= ln(N); the closer alpha is to 0, the closer the attention scores get to a one-hot distribution
class CrossAttention(layers.Layer):
    def __init__(self, num_head, dim, alpha, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.num_head = num_head
        self.dim = dim
        self.layernorm = layers.LayerNormalization()

    def build(self, input_shape):
        self.qdk = self.add_weight(name='query_dense_kernel',
                                   shape=[input_shape[0][-1], self.num_head, self.dim])
        self.kdk = self.add_weight(name='key_dense_kernel',
                                   shape=[input_shape[1][-1], self.num_head, self.dim])
        self.vdk = self.add_weight(name='value_dense_kernel',
                                   shape=[input_shape[1][-1], self.num_head, self.dim])
        self.odk = self.add_weight(name='output_dense_kernel',
                                   shape=[self.num_head, self.dim, self.dim])
        self.odb = self.add_weight(name='output_dense_bias', shape=[self.dim])

    def call(self, inputs, *args, **kwargs):
        Q, KV = inputs
        # project to multi-head queries/keys/values: (B, L or N, num_head, dim)
        query = tf.einsum("abc, cde->abde", Q, self.qdk)
        key = tf.einsum("abc, cde->abde", KV, self.kdk)
        value = tf.einsum("abc, cde->abde", KV, self.vdk)
        # scaled dot-product attention; scores: (B, num_head, L, N), softmax over the key axis
        query = tf.multiply(query, 1.0 / tf.math.sqrt(float(self.dim)))
        attention_scores = tf.math.softmax(tf.einsum("abcd, aecd->acbe", query, key))
        # entropy regularization: push the entropy of each attention row toward alpha
        entropy = -tf.reduce_sum(attention_scores * tf.math.log(attention_scores + 1e-07), axis=-1)
        self.add_loss(tf.reduce_mean((entropy - self.alpha) ** 2))
        # weighted sum of values: (B, L, num_head, dim)
        attention_output = tf.einsum("abcd, aceb->aecd", value, attention_scores)
        # merge heads with the output projection: (B, L, dim)
        output = tf.einsum("abcd, cde->abe", attention_output, self.odk) + self.odb
        # residual connection and layer norm (assumes the input feature size d equals dim)
        return self.layernorm(output + Q), attention_scores
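As a quick sanity check (a minimal usage sketch, not part of the original script; the shapes follow the comments above), calling the layer on random tensors returns an output with the same shape as Q, attention scores of shape (B, num_head, L, N), and one entropy regularization term in layer.losses:

ca = CrossAttention(num_head=3, dim=16, alpha=0.2)
Q = tf.random.uniform((2, 4, 16))    # (B, L, d); here d must equal dim for the residual connection
KV = tf.random.uniform((2, 6, 16))   # (B, N, d)
out, scores = ca((Q, KV))
print(out.shape)       # (2, 4, 16)
print(scores.shape)    # (2, 3, 4, 6) = (B, num_head, L, N)
print(len(ca.losses))  # 1: the entropy regularization term registered by add_loss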
The loss function contains two types of terms: the prediction loss and the regularization loss.
The regularization loss has to be registered with the add_loss method. Losses added via add_loss can be accessed through model.losses, which returns a collection with one element per registered regularization loss.
Once a regularization loss has been added with add_loss, the forward pass that produces it must run inside the scope of tf.GradientTape() so that it is recorded and can contribute gradients.
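Here is a minimal sketch of that point (assuming the CrossAttention layer defined above): when the forward pass runs inside the tape, the entropy term registered by add_loss is recorded and gradients flow back to the query/key projections that produce the attention scores. The fit function below applies the same pattern and adds the prediction loss as well:

ca = CrossAttention(num_head=3, dim=16, alpha=0.2)
Q = tf.random.uniform((1, 4, 16))
KV = tf.random.uniform((1, 6, 16))
with tf.GradientTape() as tape:
    _ = ca((Q, KV))                 # forward pass inside the tape records the entropy term
    reg_loss = tf.add_n(ca.losses)  # ca.losses holds one tensor per add_loss call
# the entropy term depends on the query/key projections, so their gradients are non-None
grads = tape.gradient(reg_loss, [ca.qdk, ca.kdk])
print([g is not None for g in grads])  # [True, True]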
def fit(x, y, epochs, model):
    optimizer = tf.keras.optimizers.Adam()
    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch,))
        with tf.GradientTape() as tape:
            logits = model(x, training=True)[0]
            # Compute the prediction loss for this minibatch.
            loss_value = tf.keras.losses.binary_crossentropy(y, logits)
            print(model.losses)  # one tensor per add_loss call
            # Add the regularization losses collected by add_loss inside the tape scope.
            loss_value += sum(model.losses)
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        print("attention scores entropy loss: %s" % (sum(model.losses)))
        print("loss: %s" % loss_value)
Next, set up a simple task with toy data to see the effect of the entropy regularization:
class model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.CA = CrossAttention(3, 16, 0.2)

    def build(self, input_shape):
        # prediction kernel: (L, dim, num_classes), pooling all query positions into the logits
        self.k = self.add_weight(name="predict_kernel", shape=[input_shape[0][-2], 16, 2])

    def call(self, inputs, *args, **kwargs):
        x, scores = self.CA(inputs)
        return tf.math.sigmoid(tf.einsum("abc, bcd->ad", x, self.k)), scores


if __name__ == '__main__':
    Q = tf.random.uniform((1, 4, 16))
    KV = tf.random.uniform((1, 6, 16))
    labels = tf.constant([[0., 1.]])
    model = model()
    fit((Q, KV), labels, epochs=1000, model=model)
    print(model((Q, KV)))
After the model has been trained, print the distribution of the attention scores: every row of attention scores is close to a one-hot distribution.
The information entropy of a probability distribution over N outcomes has a minimum of 0 and a maximum of ln(N). The minimum is attained by the one-hot distribution, which has the lowest entropy, and the maximum by the uniform distribution, which has the highest entropy. The entropy regularization loss used here is

$$\mathcal{L}_{\text{entropy}} = \operatorname{mean}_{h,i}\left(-\sum_{j=1}^{N} p_{h,i,j}\log p_{h,i,j} - \alpha\right)^{2},$$

where $p_{h,i,j}$ is the attention score that head $h$ and query position $i$ assign to key position $j$. By adjusting the size of alpha, you can control how closely the attention scores approach a one-hot distribution.
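As a quick numerical illustration (a small standalone sketch, not part of the training script): with N = 6 keys, a uniform row has entropy ln 6 ≈ 1.79, a one-hot row has entropy 0, and a near-one-hot row like the ones printed below lands close to the alpha = 0.2 target:

import math

def row_entropy(p):
    # information entropy of one attention row (natural log, same epsilon as the layer)
    return -sum(pi * math.log(pi + 1e-07) for pi in p)

print(row_entropy([1 / 6] * 6))                # ~1.79 = ln(6): uniform row, maximum entropy
print(row_entropy([1., 0., 0., 0., 0., 0.]))   # ~0.0: one-hot row, minimum entropy
print(row_entropy([0.0033, 0.0251, 0.9591, 0.0062, 0.0044, 0.0019]))  # ~0.22, near alpha = 0.2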
<tf.Tensor: shape=(1, 3, 4, 6), dtype=float32, numpy=
array([[[[3.27099487e-03, 2.51237601e-02, 9.59103048e-01, 6.17671618e-03, 4.41661570e-03, 1.90874794e-03],
         [3.13497148e-03, 2.45431308e-02, 9.60317731e-01, 5.98660251e-03, 4.21733223e-03, 1.80015271e-03],
         [3.40689556e-03, 2.57629622e-02, 9.57871556e-01, 6.44017057e-03, 4.52079810e-03, 1.99752697e-03],
         [1.10661890e-03, 1.27607975e-02, 9.81542170e-01, 2.43383530e-03, 1.57478068e-03, 5.81745524e-04]],

        [[3.10348310e-02, 8.87579299e-05, 4.77730093e-04, 1.46521628e-03, 9.52835977e-01, 1.40975416e-02],
         [2.63910089e-02, 5.77273713e-05, 3.39534425e-04, 1.06126023e-03, 9.60522711e-01, 1.16277682e-02],
         [3.12446002e-02, 9.48462111e-05, 5.04475611e-04, 1.46811537e-03, 9.52523112e-01, 1.41648324e-02],
         [1.50452955e-02, 1.21340790e-05, 9.36727956e-05, 3.67841218e-04, 9.78706181e-01, 5.77491429e-03]],

        [[2.43717595e-03, 2.99123675e-02, 9.57738578e-01, 8.62001721e-03, 6.63190091e-04, 6.28605427e-04],
         [2.37493636e-03, 2.91979928e-02, 9.58939075e-01, 8.24017916e-03, 6.44378248e-04, 6.03430963e-04],
         [2.84763589e-03, 3.26210111e-02, 9.53170657e-01, 9.78391431e-03, 8.07495147e-04, 7.69288861e-04],
         [1.05220091e-03, 1.88548081e-02, 9.75066125e-01, 4.56135161e-03, 2.38224253e-04, 2.27281591e-04]]]],
      dtype=float32)>
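The same check can be run over the whole tensor (a small verification sketch, assuming Q, KV and the trained model from the script above are still in scope): recomputing the row entropies with the layer's own formula shows every row far below the maximum ln 6 ≈ 1.79 and in the neighborhood of the alpha = 0.2 target:

_, scores = model((Q, KV))
# per-row entropy of the attention scores, shape (B, num_head, L)
row_entropies = -tf.reduce_sum(scores * tf.math.log(scores + 1e-07), axis=-1)
print(row_entropies)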