首先是位置编码组件:

import torch
import torch.nn as nn
import mathclass PositonalEncoding(nn.Module):def __init__ (self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)# [[1, 2, 3],# [4, 5, 6],# [7, 8, 9]]pe = torch.zeros(max_len, d_model)# [[0],# [1],# [2]]position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)# 位置编码固定,不更新参数# 保存模型时会保存缓冲区,在引入模型时缓冲区也被引入self.register_buffer('pe', pe)def forward(self, x):# 不计算梯度x = x + self.pe[:, :x.size(1)].requires_grad_(False)
