当前位置: 首页 > java >正文

Python 基于卷积神经网络手写数字识别

Ubuntu系统:22.04

python版本:3.9

安装依赖库:

pip install tensorflow==2.13 matplotlib numpy -i https://mirrors.aliyun.com/pypi/simple

代码实现:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import numpy as np
import matplotlib.pyplot as plt# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 数据预处理
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') / 255
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32') / 255# 构建CNN模型
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(10, activation='softmax'))# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
history = model.fit(train_images, train_labels,batch_size=128,epochs=5,verbose=1,validation_data=(test_images, test_labels))# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0)
print(f"\n测试准确率: {test_acc:.4f}")# 保存模型
model.save('mnist_cnn_model.keras')
print("模型已保存为 mnist_cnn_model.keras")# 可视化训练过程
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('模型准确率')
plt.ylabel('准确率')
plt.xlabel('训练轮次')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('模型损失')
plt.ylabel('损失')
plt.xlabel('训练轮次')
plt.legend()plt.tight_layout()
plt.savefig('training_history.png')
print("训练过程图表已保存为 training_history.png")# 测试预测
sample_idx = np.random.randint(0, len(test_images))
sample_image = test_images[sample_idx].reshape(1, 28, 28, 1)
prediction = model.predict(sample_image, verbose=0)plt.figure(figsize=(5, 3))
plt.imshow(test_images[sample_idx].reshape(28, 28), cmap='gray')
plt.title(f"真实标签: {test_labels[sample_idx]}\n预测结果: {np.argmax(prediction)}")
plt.axis('off')
plt.savefig('sample_prediction.png')
print(f"样本预测图已保存为 sample_prediction.png\n真实标签: {test_labels[sample_idx]},预测结果: {np.argmax(prediction)}")

http://www.xdnf.cn/news/9868.html

相关文章:

  • (二)视觉——工业镜头(以海康威视为例)
  • 罗马-华为
  • CC攻击的种类与特点解析
  • ElementUI表单验证指南
  • Spring Boot的启动流程,以及各个扩展点的执行顺序
  • AI视频生成加速器:Medeo如何用零门槛技术重塑内容创作
  • 【python爬虫】利用代理IP爬取filckr网站数据
  • UFSH2024 程序化生成 笔记
  • GJOI 5.27 题解
  • 增广拉格朗日时空联合规划ALTRO-iLQR (一)
  • 2.qml使用c++
  • 【C++基础知识】RAII的一个简单示例讲解
  • MySQL8.4组复制
  • SpeedFolding 论文翻译
  • “谁能进,谁不能进?”——用NAC精准控制网络访问
  • JS中class和构造函数的区别
  • Selenium 测试框架 - Kotlin
  • 制造企业搭建AI智能生产线怎么部署?
  • .NET WinForm图像识别二维码/条形码并读取其中内容
  • 01.认识Kubernetes
  • 广告流量监测和IP地址离线库
  • Nexus仓库数据高可用备份与恢复方案(下)
  • 苹果FINDMY和谷歌FIND HUB增强共享位置功能
  • offset 家族和 client 家族
  • 【第4章 图像与视频】4.1 图像的绘制
  • Next.js 布局(Layout)与模板(Template)深度解析:从原理到实战
  • 在VirtualBox中打造高效开发环境:CentOS虚拟机安装与优化指南
  • SQL正则表达式总结
  • Java面试实战:从Spring到大数据的全栈挑战
  • STM32中,如何理解看门狗