当前位置: 首页 > news >正文

鱼书第三章代码MNIST

一开始想一边学习一边敲代码,但折腾了半天也不知道该怎么弄,于是在网上参考了一些资料。

文件结构参考的这个:

鱼书P70--mnist.py的导入和应用-CSDN博客

内容参考的这个:

调用数据集mnist(下载+调用全攻略)_mnist数据集下载-CSDN博客

一开始在 B 站找了个视频,运行下面这个 main.py,其中 download=True 这一步会自动下载数据集(torchvision 会把 .gz 文件放在 ./data/MNIST/raw 下),然后把下载好的文件直接复制到我想要的那个 dataset 目录中就行:

import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import os
from PIL import Image
import numpy as np
# Preprocessing: ToTensor converts each image to a tensor scaled to [0, 1];
# Normalize then standardizes it with the given mean and std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training and test data. download=True fetches the dataset into ./data when
# it is not already there.
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_data = MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
# print(train_data[0])
# train_data[0][0].show


class Model(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 64 -> 10.

    The forward pass returns raw, unnormalized class scores (logits). No
    activation is applied after the last layer because nn.CrossEntropyLoss
    applies log-softmax itself; a ReLU there would clamp every logit to >= 0.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(28 * 28, 256)
        self.linear2 = nn.Linear(256, 64)
        self.linear3 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return 10 logits per row."""
        x = x.view(-1, 28 * 28)  # reshape to (N, 784)
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        return self.linear3(x)


model = Model()
# Cross-entropy loss: equivalent to log-softmax followed by NLL loss.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 0.2)

# Resume from previously saved model parameters when a checkpoint exists.
if os.path.exists('./model/model.pkl'):
    model.load_state_dict(torch.load('./model/model.pkl'))


def train(epoch):
    """Run one pass over ``train_loader``.

    Every 100 batches the model and optimizer state are saved under
    ``./model`` and the current loss is printed. ``epoch`` is accepted for
    the caller's loop but is not used in the body.
    """
    for index, (inputs, target) in enumerate(train_loader):
        # inputs: batch of images, target: batch of labels
        optimizer.zero_grad()
        y_prediction = model(inputs)
        loss = criterion(y_prediction, target)
        loss.backward()
        optimizer.step()
        if index % 100 == 0:
            # torch.save does not create missing parent directories.
            os.makedirs('./model', exist_ok=True)
            torch.save(model.state_dict(), './model/model.pkl')
            torch.save(optimizer.state_dict(), './model/optimizer.pkl')
            print(loss.item())


def test():
    """Print the fraction of ``test_loader`` samples classified correctly."""
    correct = 0  # number of correct predictions
    total = 0  # number of samples seen
    with torch.no_grad():  # no gradients needed for evaluation
        for inputs, target in test_loader:
            output = model(inputs)
            # torch.max returns (max values, indices); the index of the
            # largest of the 10 outputs is the predicted digit.
            _, predict = torch.max(output, 1)
            total += target.size(0)
            correct += (predict == target).sum().item()
    print(correct / total)


if __name__ == '__main__':
    for i in range(2):
        train(i)
        test()

然后mnist.py文件:

# coding: utf-8
try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np# url_base = 'https://ossci-datasets.s3.amazonaws.com/mnist/'  # mirror site
# Archive file name for each part of the dataset.
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz',
}

# Directory that holds the already-downloaded .gz files (this script previously
# pointed at C:\DeepLearning\dataset). A raw string keeps the backslashes
# literal instead of relying on invalid escape sequences such as "\p".
dataset_dir = r'D:\pycharm\MINIST\dataset'
save_file = dataset_dir + "/mnist.pkl"

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784

# The download step is commented out.
# def _download(file_name):
#     file_path = dataset_dir + "/" + file_name#     if os.path.exists(file_path):
#         return#     print("Downloading " + file_name + " ... ")
#     urllib.request.urlretrieve(url_base + file_name, file_path)
#     print("Done")# def download_mnist():
#     for v in key_file.values():
#        _download(v)def _load_label(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labelsdef _load_img(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return datadef _convert_numpy():dataset = {}dataset['train_img'] = _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label'])dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return datasetdef init_mnist():# download_mnist()                                取消下载dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return Tdef load_mnist(normalize=True, flatten=True, one_hot_label=False):"""读入MNIST数据集Parameters----------normalize : 将图像的像素值正规化为0.0~1.0one_hot_label :one_hot_label为True的情况下,标签作为one-hot数组返回one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组flatten : 是否将图像展开为一维数组Returns-------(训练图像, 训练标签), (测试图像, 测试标签)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])if __name__ == 
'__main__':init_mnist()

直接运行自己的代码,比如:

import sys, os
sys.path.append(os.pardir)  # make the parent directory importable so `dataset` can be found
from dataset.mnist import load_mnist

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)

 的时候,会出现问题:

Converting train-images-idx3-ubyte.gz to NumPy Array ...
Traceback (most recent call last):
  File "D:\pycharm\MINIST\chapter3\MNIST.py", line 5, in <module>
    (x_train, t_train), (x_test, t_test) = load_mnist( flatten=True, normalize=False)
  File "D:\pycharm\MINIST\dataset\mnist.py", line 111, in load_mnist
    init_mnist()
  File "D:\pycharm\MINIST\dataset\mnist.py", line 80, in init_mnist
    dataset = _convert_numpy()
  File "D:\pycharm\MINIST\dataset\mnist.py", line 70, in _convert_numpy
    dataset['train_img'] = _load_img(key_file['train_img'])
  File "D:\pycharm\MINIST\dataset\mnist.py", line 60, in _load_img
    with gzip.open(file_path, 'rb') as f:
  File "C:\Users\229\anaconda3\envs\yolov5s\lib\gzip.py", line 58, in open
    binary_file = GzipFile(filename, gz_mode, compresslevel)
  File "C:\Users\229\anaconda3\envs\yolov5s\lib\gzip.py", line 173, in __init__
    fileobj = self.myfileobj = builtins.open(filename, mode or 'rb')
FileNotFoundError: [Errno 2] No such file or directory: 'C:\\DeepLearning\\dataset/train-images-idx3-ubyte.gz'

路径出错:mnist.py 里原本要求把数据集保存在 C 盘的 C:\DeepLearning\dataset 下,而我把数据集放在了 D 盘,程序在 C 盘下找不到文件,所以我把 dataset_dir 改成了自己的路径:

# Directory that holds the already-downloaded dataset files (previously
# C:\DeepLearning\dataset). A raw string keeps the backslashes literal.
dataset_dir = r'D:\pycharm\MINIST\dataset'
save_file = dataset_dir + "/mnist.pkl"

可以显示结果

 

http://www.xdnf.cn/news/980011.html

相关文章:

  • LVDS系列16:Xilinx 7系输出延迟ODELAYE2
  • AI实用特性
  • 使用R进行数字信号处理:婴儿哭声分析深度解析
  • Anaconda 迁移搭建完成的 conda 环境到另一台设备
  • 涨薪技术|Docker容器技术之镜像(image)
  • Object.defineProperty()详解
  • React 18 渲染机制优化:解决浏览器卡顿的三种方案
  • AX620Q上模型部署流程
  • Spring Security是如何完成身份认证的?
  • BUG调试案例十四:TL431/TL432电路发热问题案例
  • Python训练营打卡DAY51
  • 机器学习核心概念速览
  • 基于ElasticSearch的法律法规检索系统架构实践
  • livetalking实时数字人多并发
  • uni-app项目实战笔记1--创建项目和实现首页轮播图功能
  • 告别excel:AI 驱动的数据分析指南
  • elementui使用Layout布局-对齐方式
  • input+disabled/readonly问题
  • Vue3 + TypeScript + Element Plus 表格行按钮不触发 row-click 事件、不触发勾选行,只执行按钮的 click 事件
  • Explore Image Deblurring via Encoded Blur Kernel Space论文阅读
  • 时序数据库IoTDB数据模型建模实例详解
  • Jmeter中变量如何使用?
  • MySQL 三表 JOIN 执行机制深度解析
  • 基础数论一一同余定理
  • Qt 动态插件系统QMetaObject::invokeMethod
  • 【docker】docker registry搭建私有镜像仓库
  • 开源 java android app 开发(十二)封库.aar
  • SD-WAN 技术如何助力工业物联网(IIoT)数据传输?深度解析传统方案对比与应用实践
  • Chrome 优质插件计划
  • 智慧农业物联网实训中心建设方案