当前位置: 首页 > news >正文

Mnist手写数字

运行实现:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as pltclass Net(torch.nn.Module):#net类神经网络主体def __init__(self):#4个全链接层super().__init__()self.fc1 = torch.nn.Linear(28*28, 64)#输入为28*28尺寸图像self.fc2 = torch.nn.Linear(64, 64)#中间三层都是64个节点self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)#输出为10个数字类别def forward(self, x):#前向传播x = torch.nn.functional.relu(self.fc1(x))#先全连接线性计算,再套上激活函数x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)#输出层用softmax做归一化,log_softmax是为了提高计算稳定性,套上了一个对数函数return xdef get_data_loader(is_train):#导入数据to_tensor = transforms.Compose([transforms.ToTensor()])#导入张量data_set = MNIST("", is_train, transform=to_tensor, download=True)#下载文件,”“里面对应的是下载目录,is_train指定导入训练集还是测试集return DataLoader(data_set, batch_size=15, shuffle=True)#一个批次15张图片,shuffle=true说明数据是随机打乱的,返回数据加载器def evaluate(test_data, net):#评估正确率n_correct = 0n_total = 0with torch.no_grad():for (x, y) in test_data:#取出数据outputs = net.forward(x.view(-1, 28*28))#计算神经网络预测值for i, output in enumerate(outputs):#作比较if torch.argmax(output) == y[i]:#argmax取最大预测概率的序号n_correct += 1#累加正确的n_total += 1return n_correct / n_totaldef main():train_data = get_data_loader(is_train=True)#训练集test_data = get_data_loader(is_train=False)#测试集net = Net()print("initial accuracy:", evaluate(test_data, net))#打印初始网络的正确率,接近0.1optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#以下为pytorch固定写法for epoch in range(3):#epoch是轮次for (x, y) in train_data:net.zero_grad()#初始化output = net.forward(x.view(-1, 28*28))#正向传播loss = torch.nn.functional.nll_loss(output, y)#计算差值,null_loss对数损失函数,为了匹配前面log_softmax的对数运算loss.backward()#反向误差传播optimizer.step()#优化网络参数print("epoch", epoch, "accuracy:", evaluate(test_data, net))for (n, (x, _)) in enumerate(test_data):#抽取4张图像,显示预测结果if n > 3:breakpredict = torch.argmax(net.forward(x[0].view(-1, 28*28)))plt.figure(n)plt.imshow(x[0].view(28, 28),cmap='gray')plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

中间可能会报错误:(libiomp5md.dll问题)

OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.

这个处理就是在anaconda文件夹下面搜索libiomp5md.dll,那bin下面的 libiomp5md.dll文件全部修改命名,就像我这样,两个bin文件夹下面的都改了。

 

运行结果:

两轮精确度如下:

4个数字预测图片如下:

http://www.xdnf.cn/news/753463.html

相关文章:

  • Python 中 dpkt 库的详细使用指南(强大的 Python 数据包解析库)
  • AI视频“入驻”手机,多模态成智能终端的新战场
  • 网页自动化部署(webhook方法)
  • 机器学习有监督学习sklearn实战二:六种算法对鸢尾花(Iris)数据集进行分类和特征可视化
  • 【ISP算法精粹】动手实战:用 Python 实现 Bayer 图像的黑电平校正
  • Linux 第三阶段课程:数据库基础与 SQL 应用
  • 量子语言模型——where to go
  • PHP与MYSQL结合中中的一些常用函数,HTTP协议定义,PHP进行文件编程,会话技术
  • CCPC dongbei 2025 I
  • 2025 年 AI 技能的全景解析
  • ●day 2 任务以及具体安排:第一章 数组part02
  • 子串题解——和为 K 的子数组【LeetCode】
  • 进阶日记(一)—LLMs本地部署与运行(更新中)
  • 【机器学习基础】机器学习入门核心:Jaccard相似度 (Jaccard Index) 和 Pearson相似度 (Pearson Correlation)
  • NLP学习路线图(十六):N-gram模型
  • C# 序列化技术全面解析:原理、实现与应用场景
  • 基于大模型预测的寻常型天疱疮诊疗方案研究报告
  • ERP系统中商品定价功能设计:支持渠道、会员与批发场景的灵活定价机制
  • 行业分析---小米汽车2025第一季度财报
  • 基于Python学习《Head First设计模式》第二章 观察者模式
  • 基于 Flickr30k-Entities 数据集 的 Phrase Localization
  • 动态规划第二弹:路径类问题(不同路径,珠宝的最高价值,地下城游戏)
  • rtpmixsound:实现音频混音攻击!全参数详细教程!Kali Linux教程!
  • 五、单元测试-概述入门
  • SQL进阶之旅 Day 10:执行计划解读与优化
  • FFmpeg学习笔记
  • SDL_CreateRendererWithProperties报错Parameter ‘window‘ is invalid
  • Maven概述,搭建,使用
  • leetcode-hot-100 (矩阵)
  • 设计模式——组合设计模式(结构型)