人工智能学习37-Keras手写识别预测
人工智能学习概述—快手视频
人工智能学习37-Keras手写识别预测—快手视频
测试示例代码
#从keras导入load_model 方法
from keras.saving.save import load_model
#引入keras.utils 包
import keras.utils
#从keras.datasets 引入 mnist 数据集,已经标注过的数据集
from keras.datasets import mnist
#引入图形类库,方便图形显示
import matplotlib.pyplot as plt
#引入numpy类库,方便矩阵操作
import numpy as np
#引入os操作系统类库,操作本地文件和目录
import os
#避免多库依赖警告信息
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
#设置神经网络模型存储目录,当前python源文件所在目录上一级下的saved_models目录
save_dir = '../saved_models'
#如果目录saved_models不存在,新建此目录
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
#神经网络模块名称
model_name = 'numpred_keras_trained_model.h5'
#设置神经网络分类数量,0-9个数字需要10个分类
num_classes = 10
#手写体图片高和宽,像素数
img_rows, img_cols = 28, 28
#从数据集mnist装入训练数据集和测试数据集,mnist提供load_data方法
(x_train, y_train),(x_test, y_test) = mnist.load_data()
#灰度图编码范围0-255,将编码归一化,转化为0-1之间数值
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
#将训练数据集和测试数据集转化为张量(batch,height,width,channel)
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
#将训练和测试标注数据转化为张量(batch,num_classes)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
#获取测试数据集样本数量
test_num = x_test.shape[0]
#定义获取随机整数的函数,最大数为test_num
def rand_int():
rand = np.random.RandomState(None)
return rand.randint(low=0, high=test_num)
#获取随机整数
n = rand_int()
#神经网络模型所在目录
model_path = os.path.join(save_dir, model_name)
#装载神经网络模型
model = load_model(model_path)
#预算随机获取测试集中10幅图片
pred = model.predict(x_test[n: n+10], 10)
#定义图形界面分为3行10列显示图片
plt.figure(figsize=(10, 3))
#循环显示每幅图片,标注真实值,和网络预测值
for i in range(n, n+10):
#转化整数i取值范围在[0,9]之间
k = i – n
#定义图片输出位置
plt.subplot(1, 10, k+1)
plt.subplots_adjust(wspace=2)
#格式化图片格式
t = x_test[i].reshape(28, 28)
#显示图片
plt.imshow(t, cmap='gray')
if pred[k].argmax() == y_test[i].argmax():
#预测正确时,使用绿色显示真实值和预测值
plt.title('%d,%d'
color='green')
else:
#预测错误时,使用红色显示真实值和预测值
plt.title('%d,%d' % (pred[k].argmax(), y_test[i].argmax()), color='red')
plt.xticks([])
plt.yticks([])
#显示图片窗口
plt.show()
%
运行结果: