当前位置: 首页 > web >正文

深度学习pycharm debug

深度学习中,Debug 是定位并解决代码逻辑错误(如张量维度不匹配)、训练异常(如 Loss 波动)、数据问题(如标签错误)的关键手段,通过打印维度、可视化梯度等方法确保模型正常运行、优化性能,贯穿开发全流程。

直接上实例以经典错误shape报错为例:

import torch
import torch.nn as nn
import torch.nn.functional as F# 模拟图像数据
x = torch.randn(8, 3, 64, 64)  # [B, C, H, W],batch size = 8# 模拟标签(分类任务)
labels = torch.randint(0, 5, (8,))  # 5 类问题,标签是 [8]# 模型定义
class BuggyNet(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)self.pool = nn.AdaptiveAvgPool2d((4, 4))  # 变成 [B, 32, 4, 4]self.linear = nn.Linear(32, 5)  # ❌ 故意设置错误 in_featuresdef forward(self, x):x = F.relu(self.conv1(x))        # [B, 16, 64, 64]x = F.relu(self.conv2(x))        # [B, 32, 64, 64]x = self.pool(x)                 # [B, 32, 4, 4]x = self.linear(x)               # ❌ 错误! x 是 4D,Linear 接受 2D 或 3Dreturn xmodel = BuggyNet()
criterion = nn.CrossEntropyLoss()# 前向传播
outputs = model(x)                  # 会报错
loss = criterion(outputs, labels)  # 不会执行到这里

首先设置断点:

然后进行debug右击:

 然后会出现控制台:

会出现变量和变量的信息(shape,值):

然后我们进行单步:

然后变量开始变化,当单步到24行时:

此刻x的shape是(8,32,4,4)但是在这个linear层

self.linear = nn.Linear(32, 5)  # ❌ 故意设置错误 in_features

期望输入是32,不仅维度不相同channel也不相同,所以继续单步会报错:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x4 and 32x5)

然后我们根据错误进行操作将x展平并且修改linear的输入:
x = x.view(x.size(0), -1)        # [8, 32*4*4] = [8, 512]
self.linear = nn.Linear(512, 5)  # ✅ 修复后的定义

import torch
import torch.nn as nn
import torch.nn.functional as F# 模拟图像数据
x = torch.randn(8, 3, 64, 64)  # [B, C, H, W],batch size = 8# 模拟标签(分类任务)
labels = torch.randint(0, 5, (8,))  # 5 类问题,标签是 [8]# 模型定义
class BuggyNet(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)self.pool = nn.AdaptiveAvgPool2d((4, 4))  # 变成 [B, 32, 4, 4]self.linear = nn.Linear(512, 5)  # 此处修改def forward(self, x):x = F.relu(self.conv1(x))        # [B, 16, 64, 64]x = F.relu(self.conv2(x))        # [B, 32, 64, 64]x = self.pool(x)                 # [B, 32, 4, 4]x = x.view(x.size(0), -1)        # 此处修改x = self.linear(x)               return xmodel = BuggyNet()
criterion = nn.CrossEntropyLoss()# 前向传播
outputs = model(x)                  
loss = criterion(outputs, labels)  

然后我们这样就不会报错了。

很多时候缝合模块时就是经常遇见shape问题,耐性一点关注输入输出shape这样就可以轻松解决问题。

http://www.xdnf.cn/news/10703.html

相关文章:

  • Cesium 自带的标注碰撞检测实现标注避让
  • esp32关于PWM最清晰的解释
  • 渊龙靶场-sql注入(数字型注入)
  • 快乐大冒险:解锁身体里的 “快乐密码”
  • 力扣刷题Day 68:搜索插入位置(35)
  • 如何在 Windows 11 24H2 的任务栏时钟中显示秒数
  • js的时间循环的讲解
  • 100V离线语音通断器
  • java笔记08
  • 15-2021剑侠情缘2-各种修复完善+虚拟机单机端+外网服务端整理+文本教程+视频教程
  • Linux服务器安装GUI界面工具
  • 【数据集】NCAR CESM Global Bias-Corrected CMIP5 Output to Support WRF/MPAS Research
  • Redis部署架构详解:原理、场景与最佳实践
  • Java函数式编程(中)
  • 第十二节:第五部分:集合框架:Set集合的特点、底层原理、哈希表、去重复原理
  • 《QDebug 2025年5月》
  • 基于大模型的急性乳腺炎全病程风险预测与综合治疗方案研究
  • Playwright Python API 测试:从入门到实践
  • 滑动窗口 -- 灵神刷题
  • C# 异常处理进阶:精准获取错误行号的通用方案
  • ubuntu安装devkitPro
  • 什么算得到?什么又算失去?
  • ps曝光度调整
  • 继承(全)
  • 2024年数维杯国际大学生数学建模挑战赛D题城市弹性与可持续发展能力评价解题全过程论文及程序
  • YOLOv10改进|爆改模型|涨点|C2F引入空间和通道注意力模块暴力涨点(附代码+修改教程)
  • 九(4).存在指针的引用,不存在引用的指针
  • uniapp-商城-77-shop(8.2-商品列表,地址信息添加,级联选择器picker)
  • window ollama部署模型
  • 2025年主流编程语言全面分析与学习指南