为pytorch前向和反向的Tensor生成描述性统计

为pytorch前向和反向的Tensor生成描述性统计

  • 代码

在调试Megatron-DeepSpeed的精度时,我们希望对比每一层前向和反向传播的输入输出误差。然而,由于数据量过大,直接保存所有数据不太现实。因此,我们生成了输入输出tensor的描述性统计信息,并等间隔抽样N个数据点,以比较这些点的相对误差,从而查找精度异常的位置。为了准确定位,我们通过类名和对象ID生成唯一的对象名称(形式为[类名-创建的第几个])以及前向和反向传播的次数。通过保存上述信息,我们可以详细记录并回溯当时的实际输入输出数据。

代码

cat > linear_test.py <<-'EOF'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from datetime import datetime# 设置设备
device = "cpu"if torch.cuda.is_available():device = "cuda:4"def is_tensor(val):# 判断是否为tensor或Parameterreturn isinstance(val, (torch.Tensor, nn.Parameter))def describe_tensor(tensor):# 返回tensor的描述,包括形状和部分数据统计信息shape = list(tensor.shape)tensor_data = tensor.cpu().float().detach().numpy().ravel()num_points = min(16, len(tensor_data))indices = np.linspace(0, len(tensor_data) - 1, num_points, dtype=int)stats = [np.max(tensor_data), np.min(tensor_data), np.mean(tensor_data), np.std(tensor_data)]sample_data = tensor_data[indices]stats_str = ",".join(f"{x:.5f}" for x in stats)sample_str = ",".join(f"{x:.5f}" for x in sample_data)return f"{shape}-{stats_str},{sample_str}"def generate_random_data(shape):# 生成符合指定形状的随机数据max_val, min_val, mean, std = 0.04025, -0.04651, 0.0, 0.00134data = np.random.normal(mean, std, shape)data = (data - data.min()) / (data.max() - data.min()) * (max_val - min_val) + min_valreturn dataindex_counter = 0def log_tensor_data(name, tensor):# 打印tensor的日志数据global index_counterindex_counter += 1timestamp = datetime.now().strftime("%H%M%S%f")if is_tensor(tensor):print(f"{timestamp},{index_counter},{name},0,{describe_tensor(tensor)}")elif isinstance(tensor, (tuple, list)):for idx, t in enumerate(tensor):if is_tensor(t):print(f"{timestamp},{index_counter},{name},{idx},{describe_tensor(t)}")def log_gradient(model):# 打印模型参数梯度信息for name, param in model.named_parameters():if param.grad is not None:log_tensor_data(f"grad-{name}", param.grad)# 对象和类名缓存
object_cache = {}
class_name_count = {}def get_unique_name(class_name, obj_id):# 生成唯一的对象名称if class_name not in class_name_count:class_name_count[class_name] = 0uid = f"{class_name}_{obj_id}"if uid not in object_cache:class_name_count[class_name] += 1object_cache[uid] = {"idx": class_name_count[class_name]}return f'{class_name}-{object_cache[uid]["idx"]}'def initialize_module_attributes(module):# 初始化模块属性if not hasattr(module, 'uuid'):module.uuid = get_unique_name(module.__class__.__name__, id(module))if not hasattr(module, 'backward_step'):module.backward_step = 0if not hasattr(module, 'forward_step'):module.forward_step = 0def forward_decorator():# 包装forward函数的修饰器def decorator(func):def wrapped(*args, **kwargs):module = args[0]initialize_module_attributes(module)module.forward_step += 1log_tensor_data(f"forward-{module.uuid}-{module.forward_step}-input", args)output = func(*args, **kwargs)log_tensor_data(f"forward-{module.uuid}-{module.forward_step}-output", output)return outputreturn wrappedreturn decoratordef pre_backward_hook(module, grad_input):# 反向传播前的钩子函数initialize_module_attributes(module)module.backward_step += 1log_tensor_data(f"backward-{module.uuid}-{module.backward_step}-input", grad_input)def post_backward_hook(module, grad_input, grad_output):# 反向传播后的钩子函数initialize_module_attributes(module)log_tensor_data(f"backward-{module.uuid}-{module.backward_step}-output", grad_output)def register_backward_hooks(module):# 注册反向传播钩子module.register_full_backward_pre_hook(pre_backward_hook)module.register_full_backward_hook(post_backward_hook)class CustomLinear(nn.Module):def __init__(self, shape):super(CustomLinear, self).__init__()weight_data = torch.from_numpy(generate_random_data(shape)).half().to(device)self.weight = nn.Parameter(weight_data)self.register_parameter('bias', None)register_backward_hooks(self)@forward_decorator()def forward(self, input_):return F.linear(input_, self.weight, self.bias)class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layer1 = CustomLinear((5504, 4096))self.layer2 = CustomLinear((4096, 5504))@forward_decorator()def forward(self, input_):out = self.layer1(input_)out = self.layer2(out)return out
# 设置随机种子
np.random.seed(1)
torch.manual_seed(2)# 创建和训练模型
model = MyModel().half().to(device)
model.train()input_data = torch.from_numpy(generate_random_data((1024, 12, 4096))).half().to(device)
target_data = torch.from_numpy(generate_random_data((1024, 12, 4096))).half().to(device)for _ in range(2):outputs = model(input_data)outputs.backward(target_data)  # 使用全一的梯度来反向传播log_gradient(model)
EOF
python3 linear_test.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1424493.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

AquaCrop模型运行及结果分析、代码解析;气象、土壤、作物和管理措施等数据的准备和输入;农业水资源管理

目录 专题一 模型原理与数据要求 专题二 模型数据准备 专题三 模型运行及结果分析 专题四 参数分析 专题五 源代码分析 更多应用 AquaCrop是由世界粮食及农业组织&#xff08;FAO&#xff09;开发的一个先进模型&#xff0c;旨在研究和优化农作物的水分生产效率。这个模型…

海外住宅IP适用场景

海外住宅IP是指海外互联网服务提供商分配给海外家庭的IP地址&#xff0c;通常长期部署在特定区域的公网中&#xff0c;网络连接稳定&#xff0c;能够快速浏览指定网页内容&#xff0c;请求响应成功。它的适用场景十分广泛&#xff0c;接下来将详细介绍哪些领域会使用海外住宅IP…

错误: 找不到或无法加载主类问题(已解决)

今天在虚拟机中安装了idea2023.2的版本&#xff0c;运行代码时发现错误找不到主类&#xff01; 直接说结论&#xff1a; 我先clean了一下target&#xff0c;然后重新build&#xff0c;发现maven报错了&#xff0c;idea2023.2默认使用了内置的maven&#xff0c;然后我切换了一下…

垃圾分类管理系统java项目

文章目录 垃圾分类管理系统一、项目演示二、项目介绍三、系统部分功能截图四、部分代码展示五、底部获取项目&#xff08;9.9&#xffe5;带走&#xff09; 垃圾分类管理系统 一、项目演示 垃圾分类管理系统 二、项目介绍 系统角色&#xff1a;管理员、用户 1、登录、注册功能…

单位个人如何向期刊投稿发表文章?

在单位担任信息宣传员一职以来,我深感肩上的责任重大。每月的对外信息宣传投稿不仅是工作的核心,更是衡量我们部门成效的重要指标。起初,我满腔热血,以为只要勤勉努力,将精心撰写的稿件投至各大报社、报纸期刊的官方邮箱,就能顺利登上版面,赢得读者的青睐。然而,现实远比理想骨…

谷歌全力反击 OpenAI:Google I/O 2024 揭晓 AI 新篇章,一场激动人心的技术盛宴

&#x1f680; 谷歌全力反击 OpenAI&#xff1a;Google I/O 2024 揭晓 AI 新篇章&#xff0c;一场激动人心的技术盛宴&#xff01; 在这个人工智能的全新时代&#xff0c;只有谷歌能让你眼前一亮&#xff01;来自全球瞩目的 Google I/O 2024 开发者大会&#xff0c;谷歌用一场…

54.指针

目录 一.什么是指针&#xff1f; 二&#xff0e;定义一个指针变量 三&#xff0e;指针变量类型 四&#xff0e;取地址运算符& 五.取值运算符* 六.视频教程 一.什么是指针&#xff1f; 口语中的指针一般指指针变量&#xff0c;指针变量存放的是一个地址。普通变量存放…

Nacos+GateWay 搭建微服务架构

文章目录 1.当前项目架构分析1.请求多个模块的方式1.请求renren-fast模块开发环境生产环境 2.请求sunliving-commodity模块1.使用环境变量资源路径的方式2.开发环境 dev.env.js3.生产环境 prod.env.js 3.文件上传请求 sunliving-service模块1.请求后端接口&#xff08;开发环境…

上班族兼职新篇章:10大实战攻略,轻松年赚1-20万

对于众多上班族而言&#xff0c;如何在工作之余赚取额外收入&#xff0c;开启自己的第一份副业&#xff0c;已成为许多人心中的疑问。每个人的才能和兴趣点不尽相同&#xff0c;但都有机会找到适合自己的兼职方式。接下来&#xff0c;就让我们一起探索这10大实战攻略&#xff0…

那些年我与c++的叫板(一)--string类自实现

引子&#xff1a;我们学习了c中的string类&#xff0c;那我们能不能像以前数据结构一样自己实现string类呢&#xff1f;以下是cplusplus下的string类&#xff0c;我们参考参考&#xff01; 废话不多说&#xff0c;直接代码实现&#xff1a;&#xff08;注意函数之间的复用&…

20232803 2023-2024-2 《网络攻防实践》实践九报告

目录 1.实践内容2.实践过程2.1 手工修改可执行文件&#xff0c;改变程序执行流程&#xff0c;直接跳转到getShell函数2.2 利用foo函数的Bof漏洞&#xff0c;构造一个攻击输入字符串&#xff0c;覆盖返回地址&#xff0c;触发getShell函数2.3 注入一个自己制作的shellcode并运行…

从独立开发者到成为SeaTunnel社区的贡献者,我做对了哪些事儿?

个人介绍 大家好&#xff0c;我是闫成雨&#xff0c;目前是一名独立开发者。专注于数据开发、机器学习、资源调度算法和分布式系统。 GitHub ID: CheneyYin 个人主页&#xff1a;https://cheneyyin.github.io/ 为社区做了哪些贡献 加强了Spark引擎和Flink引擎对SeaTunnel数据…

链接表存储图(C++注释详解): 构建表 深度优先遍历 (DFS)

链接表的结构体单元: #define size 100 typedef struct node {int idx;//下一个节点的索引int wt;//权重, 也可根据实际情景存储边的信息struct node* next; }Node; Node* hd[size]; // 存储图的邻接表 链接表的的构建: int main() {int n, m;cin >> n >> m; //…

script标签以及defer和async属性

1. <script>标签 将JavaScript代码嵌入到HTML中主要方式是使用<script>元素。 使用<script>的方式有两种&#xff1a; &#xff08;1&#xff09;直接在网页中嵌入JavaScript代码&#xff1a; <script>function sayHi() {console.log("Hi"…

slugify,slug格式转换工具

目录 前言 安装 特性 基本功能 生成简单的Slug 处理特殊字符 Unicode支持 高级功能 自定义替换规则 过滤停用词 使用不同的分隔符 处理多种语言 实际应用场景 网站和博客的SEO优化 电子商务平台的产品链接 数据清洗和预处理 总结 前言 在Web开发中&#xff0c;生成易于…

在另外一个页面,让另外一个页面弹框显示操作(调佣公共的弹框)vue

大概意思是&#xff0c;登录弹框在另外一个页面中&#xff0c;而当前页面不存在&#xff0c;在当前页面中判断如果token不存在&#xff0c;就弹框出登录的弹框 最后一行 window.location.href … 如果当前用户已登录&#xff0c;则执行后续操作(注意此处&#xff0c;可不要)

汇聚荣:拼多多长期没有流量如何提高?

在电商的海洋中&#xff0c;拼多多以其独特的团购模式吸引了众多消费者的目光。然而&#xff0c;随着市场竞争的加剧和消费者需求的多样化&#xff0c;一些商家发现自家店铺的流量持续低迷&#xff0c;销售业绩难以突破。面对这样的挑战&#xff0c;如何有效提升拼多多店铺的客…

shell脚本之sort,uniq,tr,cut,sphit,paste,ecal与正则表达式

sort命令 uniq命令 tr命令 cut命令 sphit命令 paste命令 ecal命令 正则表达式 sort命令 sort命令---以行为单位对文件内容进行排序&#xff0c;也可以根据不同的数据类型来排序 比较原则是从首字符向后&#xff0c;依次按ASCII码值进行比较&#xff0c;最后将他们按升序…

思科模拟器--2.静态路由和默认路由配置24.5.15

首先&#xff0c;创建三个路由器和两个个人电脑。 接着&#xff0c;配置两台电脑的IP&#xff0c;子网掩码和默认网关 对Router 0&#xff0c;进行以下命令&#xff1a; 对Router进行以下命令&#xff1a; 对Router2进行以下命令&#xff1a; 本实验完成。 验证&#xff1a;PC…