
A Custom SamOut Model Outperforms a Transformer on a Random-Sequence Generation Task

Introduction

In sequence modeling, the Transformer architecture, built around its powerful attention mechanism, has become the default choice. This post presents SamOut, a custom model architecture that substantially outperforms a standard Transformer on a random-sequence generation task. Through a side-by-side experiment, I analyze in detail how the two models differ on the same task.

Experimental Setup

Task Description

We designed a random-sequence generation task in which the model must predict the next token. The data mixes three kinds of patterns (a condensed generator sketch follows the list):

  1. Incrementing sequences (0, 1, 2, 3, …)
  2. Repeating sequences (0, 1, 2, 3, 4, 0, 1, 2, …)
  3. Locally patterned sequences (each segment repeats a single token)
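
For concreteness, here is a minimal sketch of such a generator, condensed from the generate_sequence_data function in the full scripts at the end of this post (make_batch is just an illustrative name; the scripts call it generate_sequence_data):

import torch

def make_batch(batch_size, seq_length, voc_size):
    """Condensed sketch of the data generator used in the experiments."""
    data = torch.zeros(batch_size, seq_length, dtype=torch.long)
    for i in range(batch_size):
        if i % 3 == 0:
            # pattern 1: incrementing sequence
            data[i] = torch.arange(seq_length) % voc_size
        elif i % 3 == 1:
            # pattern 2: repeating 0..4
            pattern = torch.tensor([0, 1, 2, 3, 4])
            data[i] = pattern.repeat(seq_length // 5 + 1)[:seq_length]
        else:
            # pattern 3: one random token per fifth of the sequence
            seg = seq_length // 5
            base = torch.randint(0, voc_size, (5,))
            for j in range(5):
                data[i, j * seg:(j + 1) * seg] = base[j]
    return data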

Model Architectures

  1. Transformer model

    • Standard Transformer encoder structure
    • 2 encoder layers, 8 attention heads
    • Hidden size 512
  2. SamOut model (layer structure sketched below)

    • Custom MaxStateSuper attention mechanism
    • 2-layer decoder structure
    • Learnable mixing parameters (α1-α4)
    • Gated feed-forward network
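
These pieces combine into a decoder layer as follows; the snippet is quoted from the full SamOut script at the end of this post, with comments added (MaxStateSuper and FeedForward are defined there):

class DecoderLayer(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        self.self_attention = MaxStateSuper(hidden_size, num_heads)  # custom attention
        self.ffn = FeedForward(hidden_size)                          # gated feed-forward
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.alpha = nn.Parameter(torch.tensor(0.5))                 # learnable residual mix

    def forward(self, x, state=None):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)
        return x, state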

Training Parameters

  • Vocabulary size: 12506
  • Batch size: 32
  • Sequence length: 50
  • Training epochs: 50
  • Optimizer: Adam (lr = 0.001); see the setup excerpt below
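
Both models share the same training loop. The relevant excerpt from the scripts below also includes a plateau scheduler and gradient clipping that the list above omits:

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# halve the learning rate after 3 epochs without validation-loss improvement
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# inside the training loop, gradients are clipped before each step:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)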

Experimental Results

1. Model efficiency

Metric              | Transformer | SamOut
Parameters          | 17,011,712  | 16,481,290
Training time       | 47.64 s     | 43.19 s
Final training loss | 7.3249      | 3.1282

With fewer parameters, SamOut trains faster and converges to a much lower loss.
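
For reference, the parameter counts in the table come from the standard PyTorch idiom both scripts use:

params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {params}")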

2. Accuracy

[Figure: SamOut training curves (training/validation loss and validation accuracy), saved as training_metrics.png by the script below]

  • Initial accuracy

    • Transformer: 28.57%
    • SamOut: 67.98%
  • Final accuracy

    • Transformer: 50.00%
    • SamOut: 67.98%

SamOut holds a clear lead from the start of training, and its final accuracy is nearly 18 percentage points higher than the Transformer's.
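
Accuracy here means greedy next-token accuracy on a freshly generated validation batch; these are the relevant lines from validate_model in the scripts below:

predictions = torch.argmax(angles, dim=-1)             # greedy prediction from the raw outputs
correct = (predictions == target_tensor).sum().item()
accuracy = correct / target_tensor.numel()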

3. Loss

The Transformer's validation loss fell from 8.47 to 7.31, while SamOut's fell from 3.33 to 3.06 and stayed far lower throughout training.

4. Generation examples

Input sequence: [0, 1, 2, 3, 4]

Transformer output
[0,1,2,3,4,13,23,26,43,32,32,32,40,38,8,11,44,13,40,2251,2251,6983,11,44,40]

SamOut output
[0,1,2,3,4,3557,6592,3928,3625,8038,12288,7408,5932,11470,9788,9583,1425,8498,344,3350,5421,4488,8998,3827,8188]

SamOut's continuation is more diverse and spreads across the whole vocabulary, whereas the Transformer's collapses onto a handful of repeated low-index tokens.
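
Both continuations were sampled the same way, with temperature 0.8 and top-k 10. The sampling step, condensed from generate_sequence in the scripts below (note that both scripts pass the model outputs through sin before the softmax, apparently a leftover from an earlier sinusoidal-loss experiment):

sin_values = torch.sin(angles[:, -1, :]) / temperature       # temperature-scaled scores
probs = torch.softmax(sin_values, dim=-1)
top_probs, top_indices = probs.topk(top_k, dim=-1)           # keep the k most likely tokens
top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)  # renormalize over the top k
next_token = top_indices.gather(1, torch.multinomial(top_probs, num_samples=1))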

Analyzing the Performance Gap

1. Attention mechanism innovation

SamOut's MaxStateSuper attention combines four computation paths. In the snippet, a, b, c, and d are four linear projections of the input, and e is the running cumulative maximum of c along the sequence (torch.cummax):

term1 = a * b                              # multiplicative interaction of two streams
term2 = self.alpha1 * b + self.alpha2 * d  # learnable linear mix
term3 = a * (self.alpha3 * e + d)          # interaction with the cumulative-max state
term4 = b * (c + e)
return term1 + term2 + term3 + term4 + c * e

This mixed formulation is more flexible than the Transformer's dot-product attention: the cumulative-max term gives each position a cheap summary of its entire prefix, helping the model capture more complex sequence patterns.

2. Gated feed-forward network

SamOut's feed-forward block includes a gating mechanism:

x1 = self.ffn1(x)             # value path
x2 = self.relu(self.gate(x))  # gate path
xx = x1 * x2                  # element-wise gating
x = self.ffn2(xx)

This design strengthens the model's capacity for nonlinear transformations of sequence features.

3. Parameter efficiency

SamOut shares a single linear projection across its four streams and learns scalar mixing coefficients, which reduces the parameter count while improving expressiveness:

self.combined = nn.Linear(dim_size, 4 * dim_size, bias=False)  # one projection for all four streams
self.alpha1 = nn.Parameter(torch.tensor(0.5))  # learnable scalar coefficient (alpha2..alpha4 analogous)

4. Residual connection design

SamOut applies a residual connection in each decoder layer:

# alpha is a learnable scalar balancing the FFN path against the identity path
x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)

This design eases the vanishing-gradient problem and speeds up training.

Conclusion

On this random-sequence generation task, SamOut shows clear advantages:

  1. 9.3% faster training (43.19 s vs 47.64 s)
  2. 3.1% fewer parameters (16,481,290 vs 17,011,712)
  3. 35.96% relative accuracy improvement (67.98% vs 50.00%)
  4. Roughly 58% lower final validation loss (3.06 vs 7.31)

These results suggest that an architecture customized for a specific task can outperform a general-purpose Transformer. SamOut's key innovations are its mixed attention mechanism and its gated feed-forward network, which let it capture the patterns in these sequences more effectively.

Future work could explore SamOut's potential in natural language processing, time-series forecasting, and other domains, as well as hybrids that combine its attention mechanism with a Transformer's for better overall performance.

"Innovation is not about discarding old architectures, but about finding the new combination that best fits the problem." The experiment in this post is one illustration of that point.

SamOut script

import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


class MaxStateSuper(nn.Module):
    def __init__(self, dim_size, heads):
        super(MaxStateSuper, self).__init__()
        self.heads = heads
        assert dim_size % heads == 0, "Dimension size must be divisible by head size."
        # one shared projection produces all four streams
        self.combined = nn.Linear(dim_size, 4 * dim_size, bias=False)
        self.alpha1 = nn.Parameter(torch.tensor(0.5))
        self.alpha2 = nn.Parameter(torch.tensor(0.5))
        self.alpha3 = nn.Parameter(torch.tensor(0.5))
        self.alpha4 = nn.Parameter(torch.tensor(0.5))  # defined but unused below

    def forward(self, x, state=None):
        b, s, d = x.shape
        combined = self.combined(x).view(b, s, 4, self.heads, -1)
        out, out1, out2, out3 = combined.unbind(2)
        out = out.permute(0, 3, 1, 2)
        out1 = out1.permute(0, 3, 1, 2)
        out2 = out2.permute(0, 3, 1, 2)
        out3 = out3.permute(0, 3, 1, 2)
        # running maximum along the sequence dimension
        out4, _ = torch.cummax(out2, dim=2)
        out = self.gen_model(out, out1, out2, out3, out4)
        out = out.transpose(1, 2).contiguous().view(b, s, d)
        return out, state

    def gen_model(self, a, b, c, d, e):
        term1 = a * b
        term2 = self.alpha1 * b + self.alpha2 * d
        term3 = a * (self.alpha3 * e + d)
        term4 = b * (c + e)
        return term1 + term2 + term3 + term4 + c * e


class FeedForward(nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()
        self.ffn1 = nn.Linear(hidden_size, hidden_size)
        self.ffn2 = nn.Linear(hidden_size, hidden_size)
        self.gate = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x1 = self.ffn1(x)
        x2 = self.relu(self.gate(x))
        xx = x1 * x2
        x = self.ffn2(xx)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        self.self_attention = MaxStateSuper(hidden_size, num_heads)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, state=None):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)
        return x, state


class SimplifiedSinusoidalLoss(nn.Module):
    """Simplified loss that only looks at the value at the target position.
    (Defined but unused: the final run trains with CrossEntropyLoss.)"""

    def __init__(self, voc_size, epsilon=1e-8):
        super().__init__()
        # map vocabulary indices to angles: j -> 2*pi*j / voc_size
        self.theta = nn.Parameter(
            torch.cat([
                torch.cat([
                    torch.linspace(0.0001, torch.pi / 2, i).reshape([1, -1]),
                    torch.linspace(torch.pi / 2 + 0.2, torch.pi, voc_size + 1 - i).reshape([1, -1]),
                ], -1)
                for i in range(1, voc_size + 1)
            ], 0)[:, 1:],
            requires_grad=False)
        self.epsilon = epsilon

    def forward(self, angles, targets):
        """
        angles: predicted angles from the model (batch, seq_len, voc_size)
        targets: target token indices (batch, seq_len)
        """
        # gather the predicted angle at each target position:
        # selected_angles = angles.gather(dim=2, index=targets.unsqueeze(-1)).squeeze(-1)
        # cosine at the target position:
        # cos_values = torch.cos(selected_angles)
        # push the target position toward an extremum (cosine near 0):
        # loss = torch.mean(cos_values ** 2)
        loss = torch.mean(angles - self.theta[targets]) ** 2
        return loss


class SamOut(nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = nn.Embedding(voc_size, hidden_size, padding_idx=0)
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        # output projection to vocabulary size
        self.output = nn.Linear(hidden_size, voc_size, bias=False)

    def forward(self, x, state=None):
        x = self.em(x)
        if state is None:
            state = [None] * len(self.decoder_layers)
        for i, layer in enumerate(self.decoder_layers):
            x, state[i] = layer(x, state[i])
        x = x + x  # labeled "residual connection" in the original; this simply doubles x
        # outputs are called "angles" downstream, a leftover of the sinusoidal-loss experiment
        x = self.output(x)
        return x, state


def predict_next_token(model, input_sequence, temperature=1.0, top_k=50):
    """Predict the next token with temperature scaling."""
    model.eval()
    with torch.no_grad():
        angles, _ = model(input_sequence)
        last_angles = angles[:, -1, :]
        # take the sine and apply the temperature
        sin_values = torch.sin(last_angles) / temperature
        # softmax to get a probability distribution
        probs = torch.softmax(sin_values, dim=-1)
        # top-k sampling
        if top_k > 0:
            top_probs, top_indices = probs.topk(top_k, dim=-1)
            # renormalize
            top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
            # sample
            next_token = torch.multinomial(top_probs, num_samples=1)
            next_token = top_indices.gather(1, next_token)
        else:
            # sample directly
            next_token = torch.multinomial(probs, num_samples=1)
        return next_token.squeeze(1)


def generate_sequence(model, start_tokens, length=20, temperature=1.0, top_k=50):
    """Generate a complete sequence."""
    model.eval()
    with torch.no_grad():
        generated = start_tokens.clone()
        state = None  # initialize the hidden state
        for _ in range(length):
            input_seq = generated[:, -1:]  # feed only the last token
            angles, state = model(input_seq, state)
            sin_values = torch.sin(angles[:, -1, :]) / temperature
            probs = torch.softmax(sin_values, dim=-1)
            if top_k > 0:
                top_probs, top_indices = probs.topk(top_k, dim=-1)
                top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(top_probs, num_samples=1)
                next_token = top_indices.gather(1, next_token)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
        return generated


def validate_model(model, criterion, batch_size=16, seq_length=50, voc_size=12506):
    """Evaluate model performance on a fresh batch."""
    model.eval()
    with torch.no_grad():
        data = generate_sequence_data(batch_size, seq_length, voc_size)
        input_tensor = data[:, :-1]
        target_tensor = data[:, 1:]
        angles, _ = model(input_tensor)
        # reshape outputs and targets into the form cross-entropy expects
        loss = criterion(angles.reshape(-1, voc_size), target_tensor.reshape(-1))
        # accuracy from the raw outputs (not the sin values)
        predictions = torch.argmax(angles, dim=-1)
        correct = (predictions == target_tensor).sum().item()
        total = target_tensor.numel()
        accuracy = correct / total
        return loss.item(), accuracy


def generate_sequence_data(batch_size, seq_length, voc_size):
    """Generate structured sequence data."""
    data = torch.zeros(batch_size, seq_length, dtype=torch.long)
    for i in range(batch_size):
        if i % 3 == 0:
            # pattern 1: incrementing sequence
            data[i] = torch.arange(seq_length) % voc_size
        elif i % 3 == 1:
            # pattern 2: repeating sequence
            pattern = torch.tensor([0, 1, 2, 3, 4])
            repeats = (seq_length // len(pattern)) + 1
            data[i] = pattern.repeat(repeats)[:seq_length]
        else:
            # pattern 3: random tokens with local structure
            base = torch.randint(0, voc_size, (seq_length // 5,))
            for j in range(5):
                start = j * (seq_length // 5)
                end = (j + 1) * (seq_length // 5)
                if end > seq_length:
                    end = seq_length
                data[i, start:end] = base[j]
    return data


if __name__ == '__main__':
    # set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # hyperparameters
    voc_size = 12506
    hidden_size = 512
    learning_rate = 0.001
    batch_size = 32
    num_epochs = 50
    seq_length = 50

    # build the model
    model = SamOut(voc_size=voc_size, hidden_size=hidden_size, num_heads=8, num_layers=2)

    # count trainable parameters
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters: {params}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # learning-rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    # training history
    train_losses = []
    val_losses = []
    accuracies = []

    # training loop
    start_time = time.time()
    for epoch in range(num_epochs):
        # generate training data
        data = generate_sequence_data(batch_size, seq_length, voc_size)
        input_tensor = data[:, :-1]
        target_tensor = data[:, 1:]

        # forward pass
        model.train()
        angles, _ = model(input_tensor)
        loss = criterion(angles.reshape([-1, voc_size]), target_tensor.reshape(-1))

        # backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        # gradient clipping to prevent explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # validation
        val_loss, accuracy = validate_model(model, criterion)
        # update the learning rate
        scheduler.step(val_loss)

        # record metrics
        train_losses.append(loss.item())
        val_losses.append(val_loss)
        accuracies.append(accuracy)

        # report progress
        if (epoch + 1) % 2 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], '
                  f'Train Loss: {loss.item():.4f}, '
                  f'Val Loss: {val_loss:.4f}, '
                  f'Accuracy: {accuracy:.4f}, '
                  f'LR: {optimizer.param_groups[0]["lr"]:.6f}')

    print(f"Training finished in {time.time() - start_time:.2f} s")

    # plot training curves
    plt.figure(figsize=(12, 8))
    plt.subplot(2, 1, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.subplot(2, 1, 2)
    plt.plot(accuracies, label='Accuracy', color='green')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Validation Accuracy')
    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.show()

    # test generation
    test_input = generate_sequence_data(1, 5, voc_size)[0].unsqueeze(0)
    print("Initial input sequence:", test_input)
    generated_seq = generate_sequence(model, test_input, length=20, temperature=0.8, top_k=10)
    print("Generated sequence:", generated_seq[0].tolist())

    # # save the model
    # torch.save(model.state_dict(), 'gru_sinusoidal_model.pth')
    # print("Model saved")

# SamOut training log
# Number of parameters: 16481290
# Epoch [2/50], Train Loss: 4.3127, Val Loss: 3.3281, Accuracy: 0.6798, LR: 0.001000
# Epoch [4/50], Train Loss: 3.1816, Val Loss: 3.0913, Accuracy: 0.6416, LR: 0.001000
# Epoch [6/50], Train Loss: 3.1512, Val Loss: 3.2262, Accuracy: 0.6798, LR: 0.001000
# Epoch [8/50], Train Loss: 3.0258, Val Loss: 3.0195, Accuracy: 0.6913, LR: 0.001000
# Epoch [10/50], Train Loss: 3.0020, Val Loss: 3.0175, Accuracy: 0.7028, LR: 0.001000
# Epoch [12/50], Train Loss: 2.9735, Val Loss: 2.9120, Accuracy: 0.6913, LR: 0.001000
# Epoch [14/50], Train Loss: 3.0765, Val Loss: 3.0578, Accuracy: 0.7028, LR: 0.001000
# Epoch [16/50], Train Loss: 2.9329, Val Loss: 2.9462, Accuracy: 0.7143, LR: 0.000500
# Epoch [18/50], Train Loss: 3.1672, Val Loss: 3.0158, Accuracy: 0.7028, LR: 0.000500
# Epoch [20/50], Train Loss: 2.9105, Val Loss: 3.0082, Accuracy: 0.6913, LR: 0.000250
# Epoch [22/50], Train Loss: 3.1239, Val Loss: 2.9615, Accuracy: 0.7028, LR: 0.000250
# Epoch [24/50], Train Loss: 3.0007, Val Loss: 2.9122, Accuracy: 0.7143, LR: 0.000250
# Epoch [26/50], Train Loss: 2.8696, Val Loss: 2.9695, Accuracy: 0.6913, LR: 0.000250
# Epoch [28/50], Train Loss: 2.9758, Val Loss: 3.0564, Accuracy: 0.6913, LR: 0.000125
# Epoch [30/50], Train Loss: 2.9738, Val Loss: 2.9610, Accuracy: 0.6913, LR: 0.000125
# Epoch [32/50], Train Loss: 3.0185, Val Loss: 3.0381, Accuracy: 0.6913, LR: 0.000063
# Epoch [34/50], Train Loss: 2.9667, Val Loss: 2.9316, Accuracy: 0.7143, LR: 0.000063
# Epoch [36/50], Train Loss: 3.0201, Val Loss: 2.9142, Accuracy: 0.7028, LR: 0.000063
# Epoch [38/50], Train Loss: 3.0144, Val Loss: 2.9645, Accuracy: 0.7028, LR: 0.000063
# Epoch [40/50], Train Loss: 2.9428, Val Loss: 2.9646, Accuracy: 0.6913, LR: 0.000031
# Epoch [42/50], Train Loss: 2.8396, Val Loss: 3.0575, Accuracy: 0.6913, LR: 0.000031
# Epoch [44/50], Train Loss: 2.8284, Val Loss: 2.9114, Accuracy: 0.7028, LR: 0.000016
# Epoch [46/50], Train Loss: 2.9099, Val Loss: 3.1118, Accuracy: 0.6913, LR: 0.000016
# Epoch [48/50], Train Loss: 2.9837, Val Loss: 3.0297, Accuracy: 0.6913, LR: 0.000008
# Epoch [50/50], Train Loss: 3.1282, Val Loss: 3.0566, Accuracy: 0.6798, LR: 0.000008
# Training finished in 43.19 s
# Initial input sequence: tensor([[0, 1, 2, 3, 4]])
# Generated sequence: [0, 1, 2, 3, 4, 3557, 6592, 3928, 3625, 8038, 12288, 7408, 5932, 11470, 9788, 9583, 1425, 8498, 344, 3350, 5421, 4488, 8998, 3827, 8188]
# Process finished with exit code 0

Transformer script

import math
import time
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class TransformerSamOut(nn.Module):
    """Transformer-based baseline architecture."""

    def __init__(self, voc_size, hidden_size, nhead=8, nlayers=6):
        super(TransformerSamOut, self).__init__()
        self.voc_size = voc_size
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(voc_size, hidden_size, padding_idx=0)
        # positional encoding
        self.pos_encoder = PositionalEncoding(hidden_size)
        # Transformer encoder layers
        encoder_layers = TransformerEncoderLayer(hidden_size, nhead, dim_feedforward=hidden_size * 2,
                                                 batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        # "angle" prediction layer
        self.angle_predictor = nn.Linear(hidden_size, voc_size, bias=False)
        # initialize weights
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.angle_predictor.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, state=None):
        # embedding
        x = self.embedding(x) * math.sqrt(self.hidden_size)
        # add positional encoding
        x = self.pos_encoder(x)
        # Transformer encoding
        x = self.transformer_encoder(x)
        # "angle" prediction, constrained to [0, pi]
        angles = self.angle_predictor(x)
        angles = torch.sigmoid(angles) * math.pi  # map to [0, pi]
        return angles, None  # no recurrent state


class PositionalEncoding(nn.Module):
    """Positional encoding layer."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=0.1)
        # compute the encodings once
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


def predict_next_token(model, input_sequence, temperature=1.0, top_k=50):
    """Predict the next token with temperature scaling."""
    model.eval()
    with torch.no_grad():
        angles, _ = model(input_sequence)
        last_angles = angles[:, -1, :]
        # take the sine and apply the temperature
        sin_values = torch.sin(last_angles) / temperature
        # softmax to get a probability distribution
        probs = torch.softmax(sin_values, dim=-1)
        # top-k sampling
        if top_k > 0:
            top_probs, top_indices = probs.topk(top_k, dim=-1)
            # renormalize
            top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
            # sample
            next_token = torch.multinomial(top_probs, num_samples=1)
            next_token = top_indices.gather(1, next_token)
        else:
            # sample directly
            next_token = torch.multinomial(probs, num_samples=1)
        return next_token.squeeze(1)


def generate_sequence(model, start_tokens, length=20, temperature=1.0, top_k=50):
    """Generate a complete sequence."""
    model.eval()
    with torch.no_grad():
        generated = start_tokens.clone()
        state = None  # initialize the hidden state
        for _ in range(length):
            input_seq = generated[:, -1:]  # feed only the last token
            angles, state = model(input_seq, state)
            sin_values = torch.sin(angles[:, -1, :]) / temperature
            probs = torch.softmax(sin_values, dim=-1)
            if top_k > 0:
                top_probs, top_indices = probs.topk(top_k, dim=-1)
                top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(top_probs, num_samples=1)
                next_token = top_indices.gather(1, next_token)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
        return generated


def validate_model(model, criterion, batch_size=16, seq_length=50, voc_size=12506):
    """Evaluate model performance on a fresh batch."""
    model.eval()
    with torch.no_grad():
        data = generate_sequence_data(batch_size, seq_length, voc_size)
        input_tensor = data[:, :-1]
        target_tensor = data[:, 1:]
        angles, _ = model(input_tensor)
        # reshape outputs and targets into the form cross-entropy expects
        loss = criterion(angles.reshape(-1, voc_size), target_tensor.reshape(-1))
        # accuracy from the raw outputs (not the sin values)
        predictions = torch.argmax(angles, dim=-1)
        correct = (predictions == target_tensor).sum().item()
        total = target_tensor.numel()
        accuracy = correct / total
        return loss.item(), accuracy


def generate_sequence_data(batch_size, seq_length, voc_size):
    """Generate structured sequence data."""
    data = torch.zeros(batch_size, seq_length, dtype=torch.long)
    for i in range(batch_size):
        if i % 3 == 0:
            # pattern 1: incrementing sequence
            data[i] = torch.arange(seq_length) % voc_size
        elif i % 3 == 1:
            # pattern 2: repeating sequence
            pattern = torch.tensor([0, 1, 2, 3, 4])
            repeats = (seq_length // len(pattern)) + 1
            data[i] = pattern.repeat(repeats)[:seq_length]
        else:
            # pattern 3: random tokens with local structure
            base = torch.randint(0, voc_size, (seq_length // 5,))
            for j in range(5):
                start = j * (seq_length // 5)
                end = (j + 1) * (seq_length // 5)
                if end > seq_length:
                    end = seq_length
                data[i, start:end] = base[j]
    return data


if __name__ == '__main__':
    # set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # hyperparameters
    voc_size = 12506
    hidden_size = 512
    learning_rate = 0.001
    batch_size = 32
    num_epochs = 50
    seq_length = 50

    # build the model
    model = TransformerSamOut(voc_size=voc_size, hidden_size=hidden_size, nhead=8, nlayers=2)

    # count trainable parameters
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters: {params}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # learning-rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    # training history
    train_losses = []
    val_losses = []
    accuracies = []

    # training loop
    start_time = time.time()
    for epoch in range(num_epochs):
        # generate training data
        data = generate_sequence_data(batch_size, seq_length, voc_size)
        input_tensor = data[:, :-1]
        target_tensor = data[:, 1:]

        # forward pass
        model.train()
        angles, _ = model(input_tensor)
        loss = criterion(angles.reshape([-1, voc_size]), target_tensor.reshape(-1))

        # backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        # gradient clipping to prevent explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # validation
        val_loss, accuracy = validate_model(model, criterion)
        # update the learning rate
        scheduler.step(val_loss)

        # record metrics
        train_losses.append(loss.item())
        val_losses.append(val_loss)
        accuracies.append(accuracy)

        # report progress
        if (epoch + 1) % 2 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], '
                  f'Train Loss: {loss.item():.4f}, '
                  f'Val Loss: {val_loss:.4f}, '
                  f'Accuracy: {accuracy:.4f}, '
                  f'LR: {optimizer.param_groups[0]["lr"]:.6f}')

    print(f"Training finished in {time.time() - start_time:.2f} s")

    # plot training curves
    plt.figure(figsize=(12, 8))
    plt.subplot(2, 1, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.subplot(2, 1, 2)
    plt.plot(accuracies, label='Accuracy', color='green')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Validation Accuracy')
    plt.tight_layout()
    plt.savefig('training_metrics1.png')
    plt.show()

    # test generation
    test_input = generate_sequence_data(1, 5, voc_size)[0].unsqueeze(0)
    print("Initial input sequence:", test_input)
    generated_seq = generate_sequence(model, test_input, length=20, temperature=0.8, top_k=10)
    print("Generated sequence:", generated_seq[0].tolist())

    # save the model
    # torch.save(model.state_dict(), 'gru_sinusoidal_model.pth')
    # print("Model saved")

# Transformer training log
# Number of parameters: 17011712
# Epoch [2/50], Train Loss: 8.7828, Val Loss: 8.4675, Accuracy: 0.2857, LR: 0.001000
# Epoch [4/50], Train Loss: 8.3525, Val Loss: 8.2212, Accuracy: 0.2219, LR: 0.001000
# Epoch [6/50], Train Loss: 8.1063, Val Loss: 7.8811, Accuracy: 0.2143, LR: 0.001000
# Epoch [8/50], Train Loss: 7.8079, Val Loss: 7.6957, Accuracy: 0.0867, LR: 0.001000
# Epoch [10/50], Train Loss: 7.5909, Val Loss: 7.5769, Accuracy: 0.0714, LR: 0.001000
# Epoch [12/50], Train Loss: 7.5184, Val Loss: 7.5046, Accuracy: 0.0714, LR: 0.001000
# Epoch [14/50], Train Loss: 7.4894, Val Loss: 7.4595, Accuracy: 0.0867, LR: 0.001000
# Epoch [16/50], Train Loss: 7.4531, Val Loss: 7.4327, Accuracy: 0.1250, LR: 0.001000
# Epoch [18/50], Train Loss: 7.4244, Val Loss: 7.4121, Accuracy: 0.2296, LR: 0.001000
# Epoch [20/50], Train Loss: 7.4023, Val Loss: 7.3907, Accuracy: 0.2679, LR: 0.001000
# Epoch [22/50], Train Loss: 7.3842, Val Loss: 7.3815, Accuracy: 0.4133, LR: 0.001000
# Epoch [24/50], Train Loss: 7.3851, Val Loss: 7.3528, Accuracy: 0.4668, LR: 0.001000
# Epoch [26/50], Train Loss: 7.3554, Val Loss: 7.3714, Accuracy: 0.4362, LR: 0.001000
# Epoch [28/50], Train Loss: 7.3618, Val Loss: 7.3531, Accuracy: 0.3444, LR: 0.000500
# Epoch [30/50], Train Loss: 7.3656, Val Loss: 7.3212, Accuracy: 0.3138, LR: 0.000500
# Epoch [32/50], Train Loss: 7.3372, Val Loss: 7.3559, Accuracy: 0.3214, LR: 0.000500
# Epoch [34/50], Train Loss: 7.3486, Val Loss: 7.3471, Accuracy: 0.3750, LR: 0.000250
# Epoch [36/50], Train Loss: 7.3454, Val Loss: 7.3453, Accuracy: 0.4273, LR: 0.000250
# Epoch [38/50], Train Loss: 7.3494, Val Loss: 7.3443, Accuracy: 0.4949, LR: 0.000125
# Epoch [40/50], Train Loss: 7.3139, Val Loss: 7.3464, Accuracy: 0.5230, LR: 0.000125
# Epoch [42/50], Train Loss: 7.3496, Val Loss: 7.3431, Accuracy: 0.5230, LR: 0.000063
# Epoch [44/50], Train Loss: 7.3441, Val Loss: 7.3441, Accuracy: 0.5230, LR: 0.000063
# Epoch [46/50], Train Loss: 7.3255, Val Loss: 7.3122, Accuracy: 0.5153, LR: 0.000063
# Epoch [48/50], Train Loss: 7.3381, Val Loss: 7.3380, Accuracy: 0.5153, LR: 0.000063
# Epoch [50/50], Train Loss: 7.3249, Val Loss: 7.3147, Accuracy: 0.5000, LR: 0.000031
# Training finished in 47.64 s
# Initial input sequence: tensor([[0, 1, 2, 3, 4]])
# Generated sequence: [0, 1, 2, 3, 4, 13, 23, 26, 43, 32, 32, 32, 40, 38, 8, 11, 44, 13, 40, 2251, 2251, 6983, 11, 44, 40]
# Process finished with exit code 0
