loss = -F.log_softmax(logits[:, -1, :], dim=1)[0, irrational_id]
loss = -F.log_softmax(logits[:, -1, :], dim=1)[0, irrational_id]
解释这行代码 loss = -F.log_softmax(logits[:, -1, :], dim=1)[0, irrational_id]
的具体含义和背后的数学原理。
1. 代码拆解与功能
这行代码的核心是计算负对数似然损失(Negative Log Likelihood Loss),用于引导模型生成特定标签(这里是“偏激”)。我们逐步拆解:
1.1 logits[:, -1, :]
logits
:模型输出的原始预测分数,形状通常为[batch_size, seq_len, vocab_size]
。[:, -1, :]
:取最后一个时间步(token)的预测分数,形状变为[batch_size, voca