当前位置: 首页 > backend >正文

2. 手写数字预测 gui版

2. 手写数字预测 gui版

  • 背景
  • 1.界面绘制
  • 2.处理图片
  • 3. 加载模型
  • 4. 预测
  • 5.结果
  • 6.一点小问题

在这里插入图片描述

背景

做了手写数字预测的模型,但是老是跑模型太无聊了,就配合pyqt做了一个可视化界面出来玩一下

源代码可以去这里https://github.com/Leezed525/pytorch_toy拿

1.界面绘制

在这里插入图片描述

整个页面布局逻辑很简单,搭建一下就好了

class MainWindow(QMainWindow):def __init__(self):super().__init__()self.net = self.get_net()  # 获取数字预测模型self.setWindowTitle("PyQt 数字预测")self.setGeometry(100, 100, 500, 550)  # 设置主窗口的初始位置和大小,留出空间给按钮self.setFixedSize(500, 550)self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.WindowMaximizeButtonHint)central_widget = QWidget()  # 创建一个中央 QWidgetself.setCentralWidget(central_widget)  # 设置中央 QWidget 为主窗口的中心部件layout = QVBoxLayout(central_widget)  # 为中央 QWidget 创建一个垂直布局# 创建一个水平布局operation_layer = QHBoxLayout()  # 创建一个水平布局用于放置操作区域left_operation_layer = QVBoxLayout()right_operation_layer = QVBoxLayout()self.canvas = DrawingCanvas(self)  # 创建 DrawingCanvas 实例canvas_label = QLabel("请在此处绘制数字")  # 创建一个标签,提示用户在画布上绘制数字canvas_label.setAlignment(Qt.AlignmentFlag.AlignCenter)canvas_label.setStyleSheet("font-size: 20px;")  # 设置标签的样式left_operation_layer.addWidget(canvas_label)  # 将标签添加到左侧操作区域布局中left_operation_layer.addWidget(self.canvas)left_operation_layer.setStretch(0, 1)left_operation_layer.setStretch(1, 10)  # 设置画布的伸缩比例,使其占据更多空间operation_layer.addLayout(left_operation_layer)  # 将左侧操作区域布局添加到操作层布局中# 右侧操作区域self.predict_label = QLabel("预测结果: ")  # 创建一个标签,显示预测结果right_operation_layer.addWidget(self.predict_label)self.predict_digit_labels = []for i in range(10):predict_digit_label = QLabel(f"数字 {i}: 0.00%")  # 创建标签显示每个数字的预测概率self.predict_digit_labels.append(predict_digit_label)  # 将标签添加到列表中for label in self.predict_digit_labels:right_operation_layer.addWidget(label)operation_layer.addLayout(right_operation_layer)  # 将右侧操作区域布局添加到操作层布局中operation_layer.setStretch(0, 10)operation_layer.setStretch(1, 1)layout.addLayout(operation_layer)  # 将操作层布局添加到主布局中# 按钮区布局button_layout = QHBoxLayout()  # 创建一个垂直布局用于放置按钮clear_button = QPushButton("清空画布")  # 清空画布按钮clear_button.clicked.connect(self.canvas.clear_canvas)  # 连接按钮的点击信号到清空画布方法predict_button = QPushButton("预测")  # 清空画布按钮predict_button.clicked.connect(self.predict)  # 连接按钮的点击信号到预测方法button_layout.addStretch(6)button_layout.addWidget(clear_button)button_layout.addWidget(predict_button)layout.addLayout(button_layout)  # 将按钮布局添加到主布局中

其中稍微有点心智压力的区域就是画图区域,这里配合ai然后再自行修改一下就好了,逻辑就是鼠标按住然后绘制,松开后停止绘制。

canvas代码

class DrawingCanvas(QWidget):"""一个自定义的 QWidget 类,用作绘图画布。用户可以在此画布上用鼠标点击并拖动来绘制线条。"""def __init__(self, parent=None):super().__init__(parent)  # 调用父类 QWidget 的构造函数self.setWindowTitle("绘图画布")  # 设置窗口标题self.setGeometry(100, 100, 280, 280)  # 设置窗口的初始位置和大小 (x, y, width, height)self.setMinimumSize(280, 280)# 创建一个 QImage 对象作为绘图缓冲区# 所有的绘图操作都在这个 QImage 上进行,然后整体绘制到屏幕,可以避免闪烁。# QImage.Format.Format_RGB32 是 PyQt6 中推荐的 RGBA 格式,支持透明度。self.image = QImage(self.size(), QImage.Format.Format_RGB32)# 将 QImage 填充为白色。self.image.fill(Qt.GlobalColor.white)self.drawing = False  # 一个布尔标志,指示当前是否正在进行鼠标拖拽绘图self.last_point = QPoint()  # 存储鼠标上次的位置,用于绘制连续的线条# 同样,颜色常量需要通过 Qt.GlobalColor 访问。self.pen_color = Qt.GlobalColor.blackself.pen_size = 20def paintEvent(self, event):"""绘制事件处理函数。每当窗口需要被重新绘制时(例如,首次显示、窗口大小改变、调用 update() 时),Qt 就会自动调用这个方法。"""painter = QPainter(self)  # 创建一个 QPainter 对象,指定在当前 QWidget (self) 上进行绘制# 将 self.image (绘图缓冲区) 的内容绘制到当前 QWidget 的整个矩形区域内。painter.drawImage(self.rect(), self.image, self.image.rect())def mousePressEvent(self, event):# 检查是否是鼠标左键被按下。if event.button() == Qt.MouseButton.LeftButton:self.drawing = True  # 设置绘图标志为 Trueself.last_point = event.pos()  # 记录当前鼠标位置作为线条的起始点def mouseMoveEvent(self, event):"""鼠标移动事件处理函数。当鼠标在窗口内移动时触发。"""# 只有当正在绘图 (self.drawing 为 True) 并且鼠标左键被按住时才执行绘图操作。# event.buttons() 返回当前按下的所有鼠标按钮的位掩码,Qt.MouseButton.LeftButton 用于检查左键是否按下。if self.drawing and event.buttons() & Qt.MouseButton.LeftButton:painter = QPainter(self.image)  # 在 QImage (绘图缓冲区) 上创建 QPainter 进行绘制# 设置画笔的颜色、粗细和样式。painter.setPen(QPen(QColor(self.pen_color), self.pen_size,Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap, Qt.PenJoinStyle.RoundJoin))# 绘制从上次记录的点到当前鼠标位置的直线painter.drawLine(self.last_point, event.pos())self.last_point = event.pos()  # 更新 last_point 为当前鼠标位置,为下一次绘制做准备self.update()  # 请求窗口重绘。这会间接调用 paintEvent,将 QImage 的最新内容显示到屏幕上。def mouseReleaseEvent(self, event):"""鼠标释放事件处理函数。当用户释放鼠标按钮时触发。"""# 检查是否是鼠标左键被释放。if event.button() == Qt.MouseButton.LeftButton:self.drawing = False  # 停止绘图def resizeEvent(self, event):"""窗口大小改变事件处理函数。当窗口大小改变时触发。"""# 如果新窗口的宽度或高度大于当前 QImage 的尺寸,则需要创建一个新的 QImage。if self.width() > self.image.width() or self.height() > self.image.height():new_image = QImage(self.size(), QImage.Format.Format_RGB32)# 填充新图像为白色new_image.fill(Qt.GlobalColor.white)painter = QPainter(new_image)# 将旧图像的内容绘制到新图像上,以保留已有的绘图。painter.drawImage(QPoint(0, 0), self.image)self.image = new_image  # 更新 self.image 为新的 QImageself.update()  # 请求重绘窗口def clear_canvas(self):"""清空画布内容,将整个 QImage 重新填充为白色。"""self.image.fill(Qt.GlobalColor.white)self.update()  # 请求重绘以显示空白画布def set_pen_size(self, size):"""设置画笔粗细。"""self.pen_size = size

2.处理图片

当布局完成后就只需要处理将图片变成输入的过程就好了,先给代码,在讲解

    def get_image(self):"""获取当前画布上的图像数据。返回一个 QImage 对象,包含当前画布的绘图内容。"""image = self.canvas.image# 将图像缩放到 28x28 像素并转换为灰度图scaled_image = image.scaled(28, 28,Qt.AspectRatioMode.IgnoreAspectRatio,  # 不保持宽高比Qt.TransformationMode.SmoothTransformation  # 平滑缩放)# 转换为 8 位灰度图grayscale_image = scaled_image.convertToFormat(QImage.Format.Format_Grayscale8)# 使用 qimage2ndarray.byte_view() 获取 NumPy 数组arr_3d = qimage2ndarray.byte_view(grayscale_image)arr = arr_3d.squeeze()# 将 NumPy 数组转换为 PyTorch 张量tensor_image = torch.from_numpy(arr).float()# --- 关键修正:添加颜色反转和标准化 ---# 1. 将像素值从 [0, 255] 归一化到 [0.0, 1.0]tensor_image = tensor_image / 255.0# 2. 颜色反转:如果你的模型是基于白色数字黑色背景训练的 而画布是黑色数字白色背景,则需要反转颜色tensor_image = 1.0 - tensor_image# 3. 标准化:应用训练时使用的均值和标准差# MNIST 均值和标准差mean = 0.1307std = 0.3081tensor_image = (tensor_image - mean) / std# 添加批次维度和通道维度,使形状变为 (1, 1, 28, 28)tensor_image = tensor_image.unsqueeze(0).unsqueeze(0).cuda()# --- 可视化 PyTorch 张量 ---# 为了可视化,我们先将其恢复到 [0,1] 范围,否则标准化后的值可能很难看# 逆标准化 (用于可视化,不影响模型输入)# visual_tensor = tensor_image * std + mean# # 确保在 [0,1] 范围内# visual_tensor = torch.clamp(visual_tensor, 0.0, 1.0)# plt.figure(figsize=(2, 2))# plt.imshow(visual_tensor.cpu().squeeze().numpy(), cmap='gray')# plt.title("input")# plt.axis('off')# plt.show()return tensor_image

其中有几个注意点
1.
目前的画布是白色的,画笔是黑色,但是mnist数据集的底是黑色的,画笔是白色的,因此需要使用

tensor_image = 1.0 - tensor_image

来将颜色取反,不然跟训练数据不一样模型无法良好运行。
2.
QT中的image是Qimage,转换成numpy代码有点麻烦,我这里图省事直接用了qimage2ndarray库,因此只需一行代码

arr_3d = qimage2ndarray.byte_view(grayscale_image)

就完成了这个操作。
3.
在输入到模型之前,要进行数据预处理,如上面的代码中

        # 3. 标准化:应用训练时使用的均值和标准差# MNIST 均值和标准差mean = 0.1307std = 0.3081tensor_image = (tensor_image - mean) / std

来优化模型效果。

3. 加载模型

这里的预训练权重就直接用了上一篇文章中训练出来的权重,还给她放到cuda上了,不过这么小的模型其实放不放其实都无所谓,没有太大的影响。

    def get_net(self):"""获取数字预测模型。返回一个 DigitCNN 模型实例。"""# 创建并返回一个 DigitCNN 模型实例net = DigitCNN()net.eval()net.cuda()net.load_state_dict(torch.load('./digit_CNN.pth'))return net

4. 预测

这里就没什么好说的了,就是简单地预测然后将结果同步到gui上了。

    def predict(self):"""预测当前画布上绘制的数字。这里可以调用模型进行预测,并更新预测结果标签。"""input = self.get_image()  # 获取当前画布上的图像数据# 使用模型进行预测with torch.no_grad():output = self.net(input)# 获取预测结果self.update_predict_result(output)def update_predict_result(self, output):_, predict = output.max(1)  # 获取预测的数字类别predict = predict.cpu().numpy()[0]# 更新预测结果标签self.predict_label.setText(f"预测结果: {predict}")# 更新每个数字的预测概率probabilities = torch.softmax(output, dim=1).cpu().numpy()[0]for i, label in enumerate(self.predict_digit_labels):label.setText(f"数字 {i}: {probabilities[i] * 100:.2f}%")

5.结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

6.一点小问题

现在模型是可以用了,但是因为Mnist数据集本身的局限性,已经网络也比较小,泛化性能比较差(但是没差到不能用的地步),所以预测结果又是后会比较奇怪,例如:

.在这里插入图片描述
这是mnist数据集中的数据,可以看出这里的0大部分都是上面闭合,导致模型预测奇怪位置的闭合的0会失准。

还有其中的4大部分都是开口的,并没有闭合4上面的开口,导致写一个很标准的4反倒有时候会预测出错,还有其他的一些问题我就不赘述了。

总之如果想要模型想要获得更好的表现,一是可以增强一下模型的能力,第二个我觉得更重的是把数据好好清洗一下,有些数据真的太差了

http://www.xdnf.cn/news/10314.html

相关文章:

  • VMvare 创建虚拟机 安装CentOS7,配置静态IP地址
  • Kubernetes架构与核心概念深度解析:Pod、Service与RBAC的奥秘
  • 算法训练第四天
  • 企业上线ESOP(电子标准操作程序)电子作业指导书,实现车间无纸化,是数字化转型的重要一步
  • ZC-OFDM雷达通信一体化减小PAPR——部分传输序列法(PTS)
  • 利用python工具you-get下载网页的视频文件
  • 学习笔记:3个学习AI路上反复看到的概念:RAG,Langchain,Agent
  • MySql(十)
  • 字符串~~~
  • 【Python训练营打卡】day40 @浙大疏锦行
  • 前端学习(7)—— HTML + CSS实现博客系统页面
  • python魔法函数
  • 《操作系统真相还原》——初探保护模式
  • 使用curlconverter网站快速生成requests请求包
  • 【Docker 新手入门指南】第十五章:常见故障排除
  • pytest 常见问题解答 (FAQ)
  • 头歌java课程实验(学习-Java字符串之正则表达式之元字符之判断字符串是否符合规则)
  • 每日c/c++题 备战蓝桥杯(P1204 [USACO1.2] 挤牛奶 Milking Cows)
  • [蓝桥杯]分考场
  • 【11408学习记录】考研英语写作提分秘籍:2013真题邀请信精讲+万能模板套用技巧
  • 1-Wire 一线式总线:从原理到实战,玩转 DS18B20 温度采集
  • AE已禁用刷新请释放Caps Lock
  • Redis事务详解:原理、使用与注意事项
  • RabbitMQ 高级特性
  • Python打卡训练营Day41
  • C 语言开发中常见的开发环境
  • python打卡day41@浙大疏锦行
  • 【愚公系列】《生产线数字化设计与仿真》006-颜色分类站仿真(配置颜色分类站的气缸和传送带)
  • YOLO系列中的C3模块解析2025.5.31
  • 《重新定义高效微调:QLoRA 4位量化的颠覆式创新解析》