当前位置: 首页 > news >正文

pytorch-利用letnet5框架深度学习手写数字识别

LetNet-5

利用letnet5框架深度学习手写数字识别
LeNet-5 项目说明
项目简介

本项目实现了经典的 LeNet-5 卷积神经网络模型,主要用于手写数字识别任务。模型结构包括两个卷积层、两个池化层和三个全连接层,适用于 MNIST 数据集。

项目结构
.
├── model.py # LeNet-5 模型定义
├── plot.py # 数据加载与可视化
├── train.py # 模型训练脚本
├── test.py # 模型测试与可视化
├── best_model.pth # 训练后的最佳模型权重
├── README.md # 项目说明文档

安装依赖

pip install torch torchvision matplotlib

数据加载与预处理

在 plot.py 中,定义了 test_Loader,用于加载 MNIST 测试数据集。数据预处理包括:

将图像转换为 Tensor

标准化图像数据

加载器使用 DataLoader 进行批处理

模型定义

在 model.py 中,定义了 LeNet-5 模型结构。模型包括以下层:

输入层:32x32 灰度图像

C1:卷积层,6 个 5x5 卷积核,输出 28x28 特征图

S2:池化层,2x2 平均池化,输出 14x14 特征图

C3:卷积层,16 个 5x5 卷积核,输出 10x10 特征图

S4:池化层,2x2 平均池化,输出 5x5 特征图

C5:卷积层,120 个 1x1 卷积核,输出 1x1 特征图

F6:全连接层,84 个神经元

输出层:10 个神经元,对应 10 个数字类别

模型训练

在 train.py 中,定义了模型训练过程,包括:

加载训练数据

定义损失函数和优化器

训练模型并保存最佳权重至 best_model.pth

模型测试与可视化

在 test.py 中,定义了模型测试过程:

加载测试数据

加载训练好的模型权重

计算测试准确率

可视化预测结果:

import torch
import matplotlib.pyplot as plt
import modeldef test_model_process(model, test_data, max_visualize=10):test_acc = 0.0test_num = 0visualize_count = 0  # 可视化计数model.eval()with torch.no_grad():for test_x, test_y in test_data:output = model(test_x)pre_label = torch.argmax(output, dim=1)test_acc += torch.sum(pre_label == test_y)test_num += test_x.size(0)# 遍历 batchfor i in range(test_x.size(0)):if visualize_count >= max_visualize:breaklabel = test_y[i].item()result = pre_label[i].item()# 可视化img = test_x[i].squeeze().cpu()  # 去掉 channelplt.imshow(img, cmap='gray')title_color = 'green' if label == result else 'red'plt.title(f"预测值:{result} 真实值:{label}", color=title_color)plt.axis('off')plt.show()# 控制台输出if label == result:print("预测值:", result, "-------", "真实值", label)else:print("预测值:", result, "-----------------------", "真实值", label)visualize_count += 1test_avg_acc = test_acc.item() / test_numprint("测试准确率:", test_avg_acc)

使用方法

训练模型:

python train.py

测试模型并可视化:

python test.py

资源连接链接🔗

http://www.xdnf.cn/news/1369657.html

相关文章:

  • Vue2(七):配置脚手架、render函数、ref属性、props配置项、mixin(混入)、插件、scoped样式
  • 深入解析交换机端口安全:Sticky MAC的工作原理与应用实践
  • 机器视觉学习-day03-灰度化实验-二值化和自适应二值化
  • 【C++】智能指针底层原理:引用计数与资源管理机制
  • 深度学习篇---LeNet-5网络结构
  • 病理软件Cellprofiler使用教程
  • vue2 和 vue3 生命周期的区别
  • 一篇文章拆解Java主流垃圾回收器及其调优方法。
  • LeetCode-22day:多维动态规划
  • 代码随想录Day62:图论(Floyd 算法精讲、A * 算法精讲、最短路算法总结、图论总结)
  • vue2和vue3的对比
  • TensorFlow 深度学习:使用 feature_column 训练心脏病分类模型
  • Day3--HOT100--42. 接雨水,3. 无重复字符的最长子串,438. 找到字符串中所有字母异位词
  • CentOS 7 服务器初始化:从 0 到 1 的安全高效配置指南
  • 肌肉力量训练
  • 木马免杀工具使用
  • 产品经理操作手册(3)——产品需求文档
  • 全链路营销增长引擎闭门会北京站开启倒计时,解码营销破局之道
  • 构建生产级 RAG 系统:从数据处理到智能体(Agent)的全流程深度解析
  • 书生大模型InternLM2:从2.6T数据到200K上下文的开源模型王者
  • word批量修改交叉引用颜色
  • 【SystemUI】新增实体键盘快捷键说明
  • 常用Nginx正则匹配规则
  • ruoyi-vue(十二)——定时任务,缓存监控,服务监控以及系统接口
  • 软件检测报告:XML外部实体(XXE)注入漏洞原因和影响
  • 服务器初始化流程***
  • 在分布式环境下正确使用MyBatis二级缓存
  • 在 UniApp 中,实现下拉刷新
  • Python爬虫: 分布式爬虫架构讲解及实现
  • IjkPlayer 播放 MP4 视频时快进导致进度回退的问题