当前位置: 首页 > ai >正文

人工智能学习37-Keras手写识别预测

人工智能学习概述—快手视频
人工智能学习37-Keras手写识别预测—快手视频

测试示例代码 
#从keras导入load_model 方法 
from keras.saving.save import load_model 
#引入keras.utils 包 
import keras.utils 
#从keras.datasets 引入 mnist 数据集,已经标注过的数据集 
from keras.datasets import mnist 
#引入图形类库,方便图形显示 
import matplotlib.pyplot as plt 
#引入numpy类库,方便矩阵操作 
import numpy as np 
#引入os操作系统类库,操作本地文件和目录 
import os 
#避免多库依赖警告信息 
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 
#设置神经网络模型存储目录,当前python源文件所在目录上一级下的saved_models目录 
save_dir = '../saved_models' 
#如果目录saved_models不存在,新建此目录 
if not os.path.isdir(save_dir): 
os.makedirs(save_dir) 
#神经网络模块名称 
model_name = 'numpred_keras_trained_model.h5' 
#设置神经网络分类数量,0-9个数字需要10个分类 
num_classes = 10 
#手写体图片高和宽,像素数 
img_rows, img_cols = 28, 28 
#从数据集mnist装入训练数据集和测试数据集,mnist提供load_data方法 
(x_train, y_train),(x_test, y_test) = mnist.load_data() 
#灰度图编码范围0-255,将编码归一化,转化为0-1之间数值 
x_train = x_train.astype('float32') / 255.0 
x_test = x_test.astype('float32') / 255.0 
#将训练数据集和测试数据集转化为张量(batch,height,width,channel) 
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 
#将训练和测试标注数据转化为张量(batch,num_classes) 
y_train = keras.utils.to_categorical(y_train, num_classes) 
y_test = keras.utils.to_categorical(y_test, num_classes) 
#获取测试数据集样本数量 
test_num = x_test.shape[0] 
#定义获取随机整数的函数,最大数为test_num 
def rand_int(): 
rand = np.random.RandomState(None) 
return rand.randint(low=0, high=test_num) 
#获取随机整数 
n = rand_int() 
#神经网络模型所在目录 
model_path = os.path.join(save_dir, model_name) 
#装载神经网络模型 
model = load_model(model_path) 
#预算随机获取测试集中10幅图片 
pred = model.predict(x_test[n: n+10], 10) 
#定义图形界面分为310列显示图片 
plt.figure(figsize=(10, 3)) 
#循环显示每幅图片,标注真实值,和网络预测值 
for i in range(n, n+10): 
#转化整数i取值范围在[0,9]之间 
k = i – n 
#定义图片输出位置 
plt.subplot(1, 10, k+1) 
plt.subplots_adjust(wspace=2) 
#格式化图片格式 
t = x_test[i].reshape(28, 28) 
#显示图片 
plt.imshow(t, cmap='gray') 
if pred[k].argmax() == y_test[i].argmax(): 
#预测正确时,使用绿色显示真实值和预测值 
plt.title('%d,%d' 
color='green') 
else: 
#预测错误时,使用红色显示真实值和预测值 
plt.title('%d,%d' % (pred[k].argmax(), y_test[i].argmax()), color='red') 
plt.xticks([]) 
plt.yticks([]) 
#显示图片窗口 
plt.show() 
% 
运行结果:

在这里插入图片描述

http://www.xdnf.cn/news/14406.html

相关文章:

  • 对于数据库触发器自动执行的理解
  • Java类的继承
  • Luckfox Pico Pi RV1106学习<3>:支持IMX415摄像头
  • BeckHoff <---> Keyence (MD-X)激光 刻印机 Profient 通讯
  • Elasticsearch:什么是混合搜索?
  • AIGC 基础篇 高等数学篇 06 向量代数与空间解析几何
  • 人月神话-学习记录
  • SQL Developer 表复制
  • Python安装与使用教程
  • Maven在依赖管理工具方面的内容
  • Java多线程通信:wait/notify与sleep的深度剖析(时序图详解)
  • Spring是如何实现有代理对象的循环依赖
  • 【SQLAlchemy系列】 SQLAlchemy 中的多条件查询:or*与 in*操作符
  • 智能土木通 - 土木工程专业知识问答系统02-RAG检索模块搭建
  • AC耦合与DC耦合
  • 体验AI智能投资!AI Hedge Fund了解一下
  • Java可变参数方法的常见错误与最佳实践
  • hyper-v虚拟机使用双屏
  • iOS —— UI(2)
  • Spring Cloud 所有组件全面总结
  • 「AI大数据」| 智慧公路大数据运营中心解决方案
  • Java类加载器与双亲委派模型深度解析
  • DNS递归查询
  • BOLL指标
  • Oracle21cR3之客户端安装错误及处理方法
  • 第11章 结构 笔记
  • 华为OD-2024年E卷-小明周末爬山[200分] -- python
  • 亚马逊ASIN: B0DNTQ2YNT数据深度解析报告
  • 3.创建数据库
  • STM32103CBT6显示ST7789通过SPI方式显示柬埔寨文