鱼书第三章代码MNIST
一开始想一边学习一边敲代码，搞了半天不知道怎么弄，于是在网上参考了一些资料：
文件结构参考的这个:
鱼书P70--mnist.py的导入和应用-CSDN博客
内容参考的这个:
调用数据集mnist(下载+调用全攻略)_mnist数据集下载-CSDN博客
一开始在 B 站找了个视频，运行下面这个 main.py 代码，其中 download=True 这一步会自动下载数据集（torchvision 会把 .gz 文件放在 ./data/MNIST/raw/ 目录下），然后把这些文件直接复制到我想要的 dataset 目录中就行。
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import os
from PIL import Image
import numpy as np
# Preprocessing: ToTensor turns each PIL image into a float tensor scaled to
# [0, 1]; Normalize then standardizes it with the commonly used MNIST
# mean/std values below.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081
transform = transforms.Compose([
    transforms.ToTensor(),  # to tensor, pixel range 0-1
    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
])

# Training and test data (downloaded into ./data on first run).
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_data = MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
# print(train_data[0])  # debug: inspect one (image, label) sample


class Model(nn.Module):
    """Three-layer fully connected classifier for 28x28 digit images.

    forward() returns raw, unnormalized class scores (logits) of shape
    (N, 10). They are meant to be passed directly to nn.CrossEntropyLoss,
    which applies log-softmax internally.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(28 * 28, 256)
        self.linear2 = nn.Linear(256, 64)
        self.linear3 = nn.Linear(64, 10)

    def forward(self, x):
        # Flatten every image to a 784-vector; reshape (unlike view) also
        # accepts non-contiguous input tensors.
        x = x.reshape(-1, 28 * 28)
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        # No activation on the output layer: a ReLU here would clamp every
        # negative logit to 0, blocking gradients for those classes and
        # producing argmax ties. CrossEntropyLoss expects raw logits.
        return self.linear3(x)


model = Model()
# Cross-entropy loss (log-softmax + NLL loss internally), so the model must
# output raw logits.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)

MODEL_DIR = './model'
MODEL_PATH = './model/model.pkl'
OPTIMIZER_PATH = './model/optimizer.pkl'

# Resume from the last saved model parameters, if a checkpoint exists.
if os.path.exists(MODEL_PATH):
    model.load_state_dict(torch.load(MODEL_PATH))


def train(epoch):
    """Train `model` for one pass over `train_loader`.

    Every 100 batches the model and optimizer state dicts are saved under
    MODEL_DIR and the current batch loss is printed.

    Args:
        epoch: index of the current epoch (informational only; not used).
    """
    # torch.save does not create missing directories, so without this the
    # first checkpoint raises FileNotFoundError on a fresh checkout.
    os.makedirs(MODEL_DIR, exist_ok=True)
    model.train()
    for index, (inputs, target) in enumerate(train_loader):
        optimizer.zero_grad()
        y_prediction = model(inputs)
        loss = criterion(y_prediction, target)
        loss.backward()
        optimizer.step()
        if index % 100 == 0:  # checkpoint and report loss every 100 batches
            torch.save(model.state_dict(), MODEL_PATH)
            torch.save(optimizer.state_dict(), OPTIMIZER_PATH)
            print(loss.item())


def test():
    """Print the classification accuracy of `model` on `test_loader`."""
    model.eval()
    correct = 0  # number of correct predictions
    total = 0  # number of evaluated samples
    with torch.no_grad():  # evaluation needs no gradients
        for inputs, target in test_loader:
            output = model(inputs)  # 10 scores per sample
            predict = output.argmax(dim=1)  # predicted digit = index of the max score
            total += target.size(0)
            correct += (predict == target).sum().item()
    print(correct / total)


if __name__ == '__main__':
    for i in range(2):
        train(i)
        test()
然后是 mnist.py 文件（放在 dataset 目录下）：
# coding: utf-8
try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np# url_base = 'https://ossci-datasets.s3.amazonaws.com/mnist/' # mirror site
# File names of the four gzip-compressed MNIST files.
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz',
}

# Directory that holds the downloaded .gz files and the mnist.pkl cache: the
# folder containing this module (e.g. D:\pycharm\MINIST\dataset). Deriving it
# from __file__ avoids a machine-specific absolute path, and the invalid
# backslash escapes a plain Windows path literal would contain.
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = os.path.join(dataset_dir, "mnist.pkl")

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784  # 1 * 28 * 28 pixels per flattened image

# Automatic download is disabled: place the four .gz files in dataset_dir
# manually.
# Automatic download is disabled; the four .gz files must already be in
# dataset_dir. The original downloader is kept for reference:
#
# def _download(file_name):
#     file_path = dataset_dir + "/" + file_name
#     if os.path.exists(file_path):
#         return
#     print("Downloading " + file_name + " ... ")
#     urllib.request.urlretrieve(url_base + file_name, file_path)
#     print("Done")
#
# def download_mnist():
#     for v in key_file.values():
#         _download(v)


def _read_gz(file_name):
    """Return the decompressed bytes of the gzip file *file_name* in dataset_dir.

    Raises FileNotFoundError with a hint about where the file is expected,
    since nothing downloads it automatically.
    """
    file_path = os.path.join(dataset_dir, file_name)
    print("Converting " + file_name + " to NumPy Array ...")
    try:
        f = gzip.open(file_path, 'rb')
    except FileNotFoundError as err:
        raise FileNotFoundError(
            "MNIST file not found: " + file_path + " (automatic download is "
            "disabled; copy " + file_name + " into dataset_dir)"
        ) from err
    with f:
        return f.read()


def _load_label(file_name):
    """Load a label file as a 1-D uint8 array."""
    # Skip the 8-byte file header that precedes the label bytes.
    labels = np.frombuffer(_read_gz(file_name), np.uint8, offset=8)
    print("Done")
    return labels


def _load_img(file_name):
    """Load an image file as a uint8 array of shape (N, img_size)."""
    # Skip the 16-byte file header that precedes the pixel bytes.
    data = np.frombuffer(_read_gz(file_name), np.uint8, offset=16)
    data = data.reshape(-1, img_size)
    print("Done")
    return data


def _convert_numpy():
    """Read all four .gz files into a dict keyed like key_file."""
    dataset = {}
    dataset['train_img'] = _load_img(key_file['train_img'])
    dataset['train_label'] = _load_label(key_file['train_label'])
    dataset['test_img'] = _load_img(key_file['test_img'])
    dataset['test_label'] = _load_label(key_file['test_label'])
    return dataset


def init_mnist():
    """Convert the .gz files to NumPy arrays and cache them in save_file."""
    # download_mnist() is intentionally not called (download disabled).
    dataset = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as f:
        pickle.dump(dataset, f, -1)
    print("Done!")


def _change_one_hot_label(X):
    """Convert a 1-D array of class labels 0-9 into a (N, 10) one-hot float array."""
    T = np.zeros((X.size, 10))
    # Set T[i, X[i]] = 1 for every row i in one vectorized assignment.
    T[np.arange(X.size), X] = 1
    return T


def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset.

    On first use the .gz files in dataset_dir are converted and cached in
    save_file (mnist.pkl); later calls read that cache.

    Parameters
    ----------
    normalize : bool
        If True, convert pixel values to float32 in the range 0.0-1.0;
        otherwise keep them as uint8 (0-255).
    flatten : bool
        If True, each image is a 1-D array of img_size values; otherwise it
        is reshaped to img_dim, i.e. (1, 28, 28).
    one_hot_label : bool
        If True, return labels as one-hot arrays such as
        [0,0,1,0,0,0,0,0,0,0]; otherwise as integer class labels.

    Returns
    -------
    (train images, train labels), (test images, test labels)
    """
    if not os.path.exists(save_file):
        init_mnist()

    # NOTE: pickle.load can execute arbitrary code. save_file is a local cache
    # written by init_mnist(); never replace it with an untrusted file.
    with open(save_file, 'rb') as f:
        dataset = pickle.load(f)

    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0

    if one_hot_label:
        dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

    if not flatten:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, *img_dim)

    return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])


if __name__ == '__main__':
    init_mnist()
然后直接运行自己的代码（例如 chapter3/MNIST.py），比如：
import os
import sys

# Put the project root (the parent of this script's directory, which contains
# the `dataset` package) on the import path. A bare os.pardir is resolved
# against the current working directory, so it only works when the script is
# launched from its own folder.
_PROJECT_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
sys.path.append(_PROJECT_ROOT)

from dataset.mnist import load_mnist  # noqa: E402  (needs the sys.path entry above)

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

print(x_train.shape)  # training images, one flattened row per image
print(t_train.shape)  # training labels
print(x_test.shape)   # test images
print(t_test.shape)   # test labels
的时候，出现了如下错误：
Converting train-images-idx3-ubyte.gz to NumPy Array ...
Traceback (most recent call last):
File "D:\pycharm\MINIST\chapter3\MNIST.py", line 5, in <module>
(x_train, t_train), (x_test, t_test) = load_mnist( flatten=True, normalize=False)
File "D:\pycharm\MINIST\dataset\mnist.py", line 111, in load_mnist
init_mnist()
File "D:\pycharm\MINIST\dataset\mnist.py", line 80, in init_mnist
dataset = _convert_numpy()
File "D:\pycharm\MINIST\dataset\mnist.py", line 70, in _convert_numpy
dataset['train_img'] = _load_img(key_file['train_img'])
File "D:\pycharm\MINIST\dataset\mnist.py", line 60, in _load_img
with gzip.open(file_path, 'rb') as f:
File "C:\Users\229\anaconda3\envs\yolov5s\lib\gzip.py", line 58, in open
binary_file = GzipFile(filename, gz_mode, compresslevel)
File "C:\Users\229\anaconda3\envs\yolov5s\lib\gzip.py", line 173, in __init__
fileobj = self.myfileobj = builtins.open(filename, mode or 'rb')
FileNotFoundError: [Errno 2] No such file or directory: 'C:\\DeepLearning\\dataset/train-images-idx3-ubyte.gz'
原因是路径出错：原来的 mnist.py 要求把数据集放在 C:\DeepLearning\dataset 下，而我把数据集放在了 D 盘的目录里，程序在 C 盘找不到文件。所以我把 mnist.py 中的 dataset_dir 改成了自己的路径：
# Put the downloaded .gz dataset files in this directory.
# A raw string (r'...') keeps the Windows backslashes from being parsed as
# escape sequences (e.g. '\d' triggers a SyntaxWarning on Python 3.12+).
dataset_dir = r'D:\pycharm\MINIST\dataset'
save_file = dataset_dir + "/mnist.pkl"
修改后就能正常运行并显示结果了。（Windows 路径最好写成原始字符串，例如 r'D:\pycharm\MINIST\dataset'，避免反斜杠被当作转义字符。）