当前位置: 首页 > backend >正文

记录一个大模型逐层微调计算损失输出少了一个维度的小bug

1.假如针对的对象是linear

def _compute_mse_on_batch(layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:inps_batch, outs_batch = next(batch_iter)print("Initial inps_batch:", inps_batch.shape)print("Initial outs_batch:", outs_batch.shape)# print("Any NaNs in inps_batch:", torch.isnan(inps_batch).any())# print("Any NaNs in outs_batch:", torch.isnan(outs_batch).any())# if inps_batch.shape[0] != 1:#     for name, value in list(kwargs.items()):#         if isinstance(value, torch.Tensor) and value.shape[0] == 1:#             if name not in ("attention_mask", "position_ids"):#                 warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")#             repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]#             kwargs[name] = value.tile(*repeats)outs_prediction= layer(inps_batch, **kwargs)assert outs_prediction.shape == outs_batch.shapeloss = F.mse_loss(outs_prediction, outs_batch)# print("Computed loss:", loss.item())return loss

2.假如针对的对象是transformer

def _compute_mse_on_batch(layer: nn.Module, batch_iter: Iterator[Tuple[torch.Tensor, torch.Tensor]], **kwargs
) -> torch.Tensor:inps_batch, outs_batch = next(batch_iter)print("Initial inps_batch:", inps_batch.shape)print("Initial outs_batch:", outs_batch.shape)# print("Any NaNs in inps_batch:", torch.isnan(inps_batch).any())# print("Any NaNs in outs_batch:", torch.isnan(outs_batch).any())# if inps_batch.shape[0] != 1:#     for name, value in list(kwargs.items()):#         if isinstance(value, torch.Tensor) and value.shape[0] == 1:#             if name not in ("attention_mask", "position_ids"):#                 warnings.warn(f"Tiling an unexpected kwarg {name} over batch size; make sure this is valid.")#             repeats = [len(inps_batch)] + [1 for _ in range(value.ndim - 1)]#             kwargs[name] = value.tile(*repeats)outs_prediction, *_unused = layer(inps_batch, **kwargs)# print("outs_prediction device in loss:", outs_prediction.device)# print(" outs_batch device in loss:",  outs_batch.device)assert outs_prediction.shape == outs_batch.shapeloss = F.mse_loss(outs_prediction, outs_batch)# print("Computed loss:", loss.item())return loss

值得注意的是,假如我们在线性层里面写的是: outs_prediction, *_unused = layer(inps_batch, **kwargs),由于线性层返回没有 *_unused ,会导致我们输入[batchsize,input]时候得到的输出不是我们期望的[batchsize,output],而只会有[output]

http://www.xdnf.cn/news/13746.html

相关文章:

  • Three.js搭建小米SU7三维汽车实战(4)场景搭建
  • 【时时三省】(C语言基础)将外部变量的作用域扩展到其他文件
  • 计算复变积分 $w = \int_0^1 (1 + it)^2 \, dt$
  • 【清晰教程】可视化数据集标注工具Labelimg零基础安装
  • openstack实例创建过程分析
  • 深度掌控,智启未来 —— 基于 STM32F103RBT6 的控制板
  • 离线部署openstack 2024.1 cinder
  • pangolin
  • 全连接层和卷积层等效情况举例
  • 离线部署openstack 2024.1控制节点keystone
  • Design Compiler:使用read_file命令读取RTL设计
  • Python Day 48 学习(日志Day18学习)
  • 谷歌被禁用的麦克风如何能使用
  • 榕壹云打车系统:赋能出租与网约车的全场景解决方案
  • 阿里1688 普通 231滑块 x82 分析
  • 前端将多个PDF链接的内容拼接成一个后返回出一个链接进行打开
  • 一起学习swin-transformer(一)
  • STM32开发GCC常用编译选项
  • 计组刷题日记(1)
  • 快速熟悉公司的服务器开发环境需要系统
  • 软件测试之APP测试要点(包含Monkey基础使用)
  • 如何创建vue工程?以及遇到问题的解决方法
  • vue3提供的hook和通常的函数有什么区别
  • Python直接访问Windows API库之pywin32使用详解
  • mysql递归查询所有父节点拼接父节点名称
  • 使用Gradle打包springboot项目为JAR包教程
  • 蓝凌EKP产品:低门槛、可扩展、可视化公式引擎应用示例
  • 功能化组件编码流程-2(延续上一章)
  • 《HarmonyOSNext属性动画实战手册:让UI丝滑起舞的魔法指南》
  • 人工智能新范式:从大模型到智能体的演化逻辑