当前位置: 首页 > news >正文

MNIST 手写数字识别模型分析

功能概述

这段代码实现了一个基于TensorFlow和Keras的MNIST手写数字识别模型。主要功能包括:

  1. 加载并预处理MNIST数据集
  2. 构建一个简单的全连接神经网络模型
  3. 训练模型并评估其性能
  4. 使用训练好的模型进行预测
  5. 保存和加载模型

代码解析

1. 导入必要的库

import matplotlib
import tensorflow.keras as keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from pasta.augment import inline
  • 导入TensorFlow和Keras用于构建和训练神经网络
  • 导入NumPy用于数值计算
  • 导入Matplotlib用于数据可视化
  • 从pasta.augment导入inline用于在Jupyter Notebook中直接显示图像

2. 打印TensorFlow版本

print(tf.__version__)

输出当前使用的TensorFlow版本,用于环境检查。

3. 加载MNIST数据集

path = '../doc/mnist.npz'
with np.load(path) as data:x_train, y_train = data['x_train'], data['y_train']x_test, y_test = data['x_test'], data['y_test']
print(x_train[0])
  • 从本地文件加载MNIST数据集
  • 数据集包含训练集(x_train, y_train)和测试集(x_test, y_test)
  • 打印第一个训练样本的像素值

4. 数据可视化

%matplotlib inline
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()
  • 使用Matplotlib显示第一个训练样本的图像
  • cmap=plt.cm.binary设置为黑白显示

5. 打印第一个训练样本的标签

print(y_train[0])

输出第一个训练样本对应的数字标签。

6. 数据归一化

x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)
print(x_train[0])
  • 对图像数据进行归一化处理,将像素值缩放到0-1范围
  • 打印归一化后的第一个训练样本

7. 构建神经网络模型

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
  • 创建一个Sequential模型
  • 添加Flatten层将28x28的图像展平为784维向量
  • 添加两个全连接层(Dense),每层128个神经元,使用ReLU激活函数
  • 添加输出层,10个神经元对应10个数字类别,使用Softmax激活函数

8. 编译模型

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
  • 使用Adam优化器
  • 使用稀疏分类交叉熵作为损失函数
  • 使用准确率作为评估指标

9. 训练模型

model.fit(x_train, y_train, epochs=3)
  • 训练模型3个epoch
  • 使用训练数据进行拟合

10. 评估模型

val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)
  • 在测试集上评估模型性能
  • 输出测试损失和准确率

11. 使用模型进行预测

predictions = model.predict(x_test)
print(predictions)
print(np.argmax(predictions[0]))
plt.imshow(x_test[0], cmap=plt.cm.binary)
plt.show()
  • 对测试集进行预测
  • 打印预测结果(概率分布)
  • 使用argmax获取第一个测试样本的预测标签
  • 显示第一个测试样本的图像

12. 保存和加载模型

def softmax_v2(x):return tf.keras.activations.softmax(x)new_model = tf.keras.models.load_model('epic_num_reader.model.keras',custom_objects={'softmax_v2': softmax_v2}
)predictions = new_model.predict(x_test)
print(np.argmax(predictions[0]))
  • 定义一个softmax_v2函数用于兼容性
  • 加载之前保存的模型
  • 使用加载的模型进行预测

总结

这段代码实现了一个简单但有效的MNIST手写数字分类器。主要特点包括:

  1. 使用全连接神经网络结构
  2. 实现了数据预处理和归一化
  3. 达到了较高的测试准确率(约97%)
  4. 包含了模型保存和加载功能
  5. 提供了可视化工具检查数据和预测结果

demo001.ipynb

# 导入 keras 模块
import matplotlib
import tensorflow.keras as keras
# 导入 tensorflow 模块
import tensorflow as tf
# 导入 pasta 模块中的 augment 和 inline 子模块
from pasta.augment import inline# 打印 TensorFlow 的版本
print(tf.__version__)# 指定本地文件路径
path = '../doc/mnist.npz'
# 导入 numpy 模块
import numpy as np
# 从本地加载 MNIST 数据集
with np.load(path) as data:x_train, y_train = data['x_train'], data['y_train']x_test, y_test = data['x_test'], data['y_test']
# 打印训练数据集的第一个样本
print(x_train[0])# 导入 matplotlib.pyplot 模块
import matplotlib.pyplot as plt
# 使用 inline 后,图形将直接显示在 Jupyter Notebook 中
# %matplotlib inline
# 可视化训练数据集的第一个样本
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()# 打印训练标签的第一个样本
print(y_train[0])# 对训练和测试数据进行归一化处理
x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)# 打印归一化后的训练数据集的第一个样本
print(x_train[0])# 可视化归一化后的训练数据集的第一个样本
plt.imshow(x_train[0], cmap=plt.cm.binary)
plt.show()# 创建一个 Sequential 模型
model = tf.keras.models.Sequential()
# 添加一个 Flatten 层,用于将输入数据展平
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
# 添加一个 Dense 层,包含 128 个神经元,使用 ReLU 激活函数
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
# 再添加一个 Dense 层,配置同上
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
# 添加一个 Dense 层,包含 10 个神经元,使用 Softmax 激活函数
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
# 编译模型,指定优化器、损失函数和评估指标
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=3)
# 评估模型
val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)# 使用模型进行预测
predictions = model.predict(x_test)
print(predictions)# 导入 numpy 模块
import numpy as np# 打印第一个测试样本的预测标签
print(np.argmax(predictions[0]))# 可视化第一个测试样本
plt.imshow(x_test[0], cmap=plt.cm.binary)
plt.show()# 保存模型
def softmax_v2(x):# 将 softmax_v2 映射到标准 softmaxreturn tf.keras.activations.softmax(x)# 加载之前保存的模型
new_model = tf.keras.models.load_model('epic_num_reader.model.keras',custom_objects={'softmax_v2': softmax_v2}
)# 使用加载的模型进行预测
predictions = new_model.predict(x_test)
# 打印第一个测试样本的预测标签
print(np.argmax(predictions[0]))
http://www.xdnf.cn/news/1180801.html

相关文章:

  • 秋叶sd-webui频繁出现生成后无反应的问题
  • 【Web APIs】JavaScript 节点操作 ⑧ ( 删除节点 - removeChild 函数 | 删除节点 - 代码示例 | 删除网页评论案例 )
  • 算法竞赛阶段二-数据结构(34)数据结构链表STL vector
  • 【PyTorch】图像二分类项目-部署
  • Spring Boot 3整合Spring AI实战:9轮面试对话解析AI应用开发
  • HttpServletRequest深度解析:Java Web开发的核心组件
  • PyTorch数据选取与索引详解:从入门到高效实践
  • Vue3 面试题及详细答案120道(91-105 )
  • 开立医疗2026年校园招聘
  • 论文复现-windows电脑在pycharm中运行.sh文件
  • 工具篇之开发IDEA插件的实战分享
  • C# 方法执行超时策略
  • 处理URL请求参数:精通`@PathVariable`、`@RequestParam`与`@MatrixVariable`
  • Lua元表(Metatable)
  • Python 使用环境下编译 FFmpeg 及 PyAV 源码(英特尔篇)
  • TDengine 转化类函数 TO_CHAR 用户手册
  • 【数字IC验证学习------- SOC 验证 和 IP验证和形式验证的区别】
  • 借助 VR 消防技术开展应急演练,检验完善应急预案​
  • 数据库底层索引讲解-排序和数据结构
  • 主流 BPM 厂商产品深度分析与选型指南:从能力解析到场景适配
  • 基于深度学习的CT图像3D重建技术研究
  • Python-初学openCV——图像预处理(二)
  • MySQL 表的操作
  • 大模型Prompt优化工程
  • Shell的正则表达式
  • JVM原理及其机制(二)
  • Web前端:JavaScript findIndex⽅法
  • MySQL数据库迁移至国产数据库测试案例
  • Spring MVC 统一响应格式:ResponseBodyAdvice 从浅入深
  • redis常用数据类型