当前位置: 首页 > ds >正文

李沐-第十章-训练Seq2SeqAttentionDecoder报错

问题

系统: win11
显卡:5060
CUDA:12.8
pytorch:2.7.1+cu128
pycharm:2025.2

训练带有注意力机制的编码器-解码器网络时,

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)net = EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

会报错unpack错误.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[8], line 118 decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)10 net = d2l.EncoderDecoder(encoder, decoder)
---> 11 train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)Cell In[7], line 41, in train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device)39 bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1, 1)40 dec_input = torch.cat([bos, Y[:,:-1]], 1)
---> 41 Y_hat, _ = net(X, dec_input, X_valid_len)42 l = loss(Y_hat, Y, Y_valid_len)43 l.sum().backward()ValueError: too many values to unpack (expected 2)

分析

排查发现d2l包里面的EncoderDecoder定义和教材中不同.
d2l包里面的定义:

class EncoderDecoder(d2l.Classifier):"""The base class for the encoder--decoder architecture.Defined in :numref:`sec_encoder-decoder`"""def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):enc_all_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_all_outputs, *args)# Return decoder output onlyreturn self.decoder(dec_X, dec_state)[0]def predict_step(self, batch, device, num_steps,save_attention_weights=False):"""Defined in :numref:`sec_seq2seq_training`"""batch = [d2l.to(a, device) for a in batch]src, tgt, src_valid_len, _ = batchenc_all_outputs = self.encoder(src, src_valid_len)dec_state = self.decoder.init_state(enc_all_outputs, src_valid_len)outputs, attention_weights = [d2l.expand_dims(tgt[:, 0], 1), ], []for _ in range(num_steps):Y, dec_state = self.decoder(outputs[-1], dec_state)outputs.append(d2l.argmax(Y, 2))# Save attention weights (to be covered later)if save_attention_weights:attention_weights.append(self.decoder.attention_weights)return d2l.concat(outputs[1:], 1), attention_weights

教材的定义(9.6.3):

class EncoderDecoder(nn.Module):def __init__(self, encoder, decoder, **kwargs):super(EncoderDecoder, self).__init__(**kwargs)self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):enc_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_outputs, *args)return self.decoder(dec_X, dec_state)

解决方法

使用教材中的EncoderDecoder定义, 正常训练不报错.
训练结果:
在这里插入图片描述

http://www.xdnf.cn/news/18946.html

相关文章:

  • Leetcode top100之链表排序
  • 【ElasticSearch】json查询语法
  • 美团一面“保持好奇”
  • Spring Boot 项目打包成可执行程序
  • HTML应用指南:利用POST请求获取全国三星门店位置信息
  • Ubuntu安装及配置Git(Ubuntu install and config Git Tools)
  • Next.js 15.5.0:探索 Turbopack Beta、稳定的 Node.js 中间件和 TypeScript 的改进
  • 30.throw抛异常
  • 【图像算法 - 23】工业应用:基于深度学习YOLO12与OpenCV的仪器仪表智能识别系统
  • 【P2P】P2P主要技术及RELAY服务1:python实现
  • Kubernetes 构建高可用、高性能 Redis 集群
  • 线性回归入门:从原理到实战的完整指南
  • k8sday17安全机制
  • 真实应急响应案例记录
  • 一键终结Win更新烦恼!你从未见过如此强大的更新暂停工具!
  • PNP机器人介绍:全球知名具身智能/AI机器人实验室介绍之多伦多大学机器人研究所
  • PC端逆向会用到的常见伪指令
  • 解读 “货位清则标识明,标识明则管理成” 的实践价值
  • 灰狼算法+四模型对比!GWO-CNN-BiLSTM-Attention系列四模型多变量时序预测
  • EasyClick 生成唯一设备码
  • 【CV】图像基本操作——①图像的IO操作
  • XC95144XL-10TQG144I Xilinx XC9500XL 高性能 CPLD
  • 从0到1:用 Qwen3-Coder 和 高德MCP 助力数字文旅建造——国庆山西游
  • 我的小灶坑
  • Web程序设计
  • 《 nmcli网络管理学习》
  • 28 FlashAttention
  • sudo 升级
  • 牛客周赛 Round 106(小苯的方格覆盖/小苯的数字折叠/ 小苯的波浪加密器/小苯的数字变换/小苯的洞数组构造/ 小苯的数组计数)
  • “华生科技杯”2025年全国青少年龙舟锦标赛在海宁举行