当前位置: 首页 > news >正文

Keras模型保存、加载介绍

目录

    • 前言:
    • 保存格式
    • 保存模型、参数和加载模型
    • 示例
    • 总结

前言:

在TensorFlow中,保存和加载模型是机器学习工作流程中的重要步骤。这不仅有助于持久化训练好的模型以便后续使用,还可以实现模型的版本控制、部署和服务。

保存格式

TensorFlow 提供了多种方式来保存和读取模型,主要分为两种格式:SavedModel和Keras的HDF5格式。

使用SavedModel格式
SavedModel是TensorFlow推荐的保存模型的方式。它保存整个模型,包括权重、架构、优化器状态等,并且支持 TensorFlow Serving 等工具。SavedModel 可以在不同平台上使用,并且可以恢复到任何语言的 TensorFlow API 中。

使用 HDF5 格式(Keras 模型)
HDF5 是一种二进制文件格式,适合保存大型数组数据,如神经网络的权重。Keras提供了简单的方法来保存和加载HDF5格式的模型。注意,HDF5文件只保存模型的架构和权重,不保存优化器的状态和其他配置。

保存模型、参数和加载模型

keras 保存成hdf5文件, 1,保存模型和参数;2, 只保存参数

保存模型和参数

##save
callback ModelCheckpoint

只保存参数

##save_weights
callback ModelCheckpoint save_weights_only=True

示例

##导包
from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt## 加载数据
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]##print(x_valid.shape, y_valid.shape)
##print(x_train.shape, y_train.shape)
##print(x_test.shape, y_test.shape)## 标准化
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(55000, -1))
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(5000, -1))x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(10000, -1))## 创建模型model = keras.models.Sequential([keras.layers.Dense(512, activation='relu', input_shape=(784,)),keras.layers.Dense(256, activation='relu'),keras.layers.Dense(128, activation='relu'),keras.layers.Dense(10, activation='softmax'),
])model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])## 训练import oslogdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
callbacks = [keras.callbacks.TensorBoard(logdir),keras.callbacks.ModelCheckpoint(output_model_file, save_best_only=True,save_weights_only=True),keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3)
]
history = model.fit(x_train_scaled, y_train, epochs=10, validation_data=(x_valid_scaled, y_valid),callbacks=callbacks)##保存模型
output_model_file2 = os.path.join(logdir, 'fashion_mnist_model.h5')
# 另一种保存参数的方法
##model.save_weights(os.path.join(logdir, 'fashin_mnist_weights_2.h5'))
model.save(output_model_file2)print(model.evaluate(x_valid_scaled, y_valid))##加载模型
model2 = keras.models.load_model(output_model_file2)print(model2.evaluate(x_valid_scaled, y_valid))

结果如下:

[0.3386252820491791, 0.8938000202178955] 
[0.3386252820491791, 0.8938000202178955]

总结

SavedModel:推荐用于生产环境,因为它保存了完整的模型信息,并且具有良好的跨平台兼容性。
HDF5:适用于简单的模型保存和加载需求,特别是当你需要与旧版本的 TensorFlow 或其他库兼容时。

http://www.xdnf.cn/news/215029.html

相关文章:

  • 技术驱动与模式创新:开源AI大模型与S2B2C商城重构零售生态
  • 在 MySQL 中建索引时需要注意哪些事项?
  • 使用Spring Boot实现WebSocket广播
  • 二叉树左叶子之和(后序遍历,递归求和)
  • VScode与远端服务器SSH链接
  • NS-SWIFT微调Qwen3
  • Electron Forge【实战】桌面应用 —— 将项目配置保存到本地
  • 【含文档+PPT+源码】基于微信小程序的乡村振兴民宿管理系统
  • BLE技术,如何高效赋能IoT短距无线通信?
  • 【展位预告】正也科技将携营销精细化管理解决方案出席中睿营销论坛
  • 数据库系统概论|第三章:关系数据库标准语言SQL—课程笔记7
  • Unity Audio DSP应用与实现
  • C++多线程与锁机制
  • JavaScript函数声明大比拼
  • yolov8使用
  • 10 基于Gazebo和Rviz实现导航仿真,包括SLAM建图,地图服务,机器人定位,路径规划
  • BIM(建筑信息模型)与GIS(地理信息系统)的融合的技术框架、实现路径与应用场景
  • 【MCP Node.js SDK 全栈进阶指南】高级篇(2):MCP高性能服务优化
  • MCP 协议 ——AI 世界的 “USB-C 接口”:从认知到实践的全面指南
  • 源码角度分析 sync.map
  • Silvaco仿真中victory process的蒙特卡洛(Monte Carlo)离子注入
  • [4-06-09].第10节:自动配置- 分析@SpringBootApplication启动类
  • github使用记录
  • Redis分布式锁使用以及对接支付宝,paypal,strip跨境支付
  • 第十六届蓝桥杯大赛网安组--几道简单题的WP
  • HTTP协议重定向及交互
  • 运放参数汇总
  • mac word接入deepseek
  • LVGL -窗口操作
  • Linux/AndroidOS中进程间的通信线程间的同步 - 管道和FIFO