当前位置: 首页 > java >正文

七、CV_模型微调

七、模型微调

1.微调

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型(预训练模型)。——修改预训练模型使他适合你的任务,最重要的是修改输出层
  2. 创建一个新的神经网络模块,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的

  • 当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力
  • 根据当前任务数据集的大小来确定微调的网络层
    • 在数据集较小时,隐藏层的参数可以不进行微调
    • 在数据集较大时,可以将隐藏层划开,里面的参数也可以进行变化

2.热狗识别案例

将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张热狗或者其他事物的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗

import tensorflow as tf
import numpy as np

(1)获取数据集

  • batch_size在读取数据和模型训练的时候均可以进行设置

通过以下方法读取图像文件,该方法以文件夹路径为参数,生成经过图像增强后的结果,并产生batch数据:

flow_from_directory(self,directory, # 目标文件夹路径,对于每一个类对应一个子文件夹,# 该文件夹中任何JPG,PNG,BNP,PPM的图片都可以读取 target_size = (256, 256), # 默认为(256,256),图像将被resize成该尺寸color_mode = 'rgb',classes = None,class_mode = 'categorical',batch_size = 32, # 默认为32shuffle = True, # 是否打乱数据,默认为Trueseed = None,save_to_dir = None)

创建两个tf.keras.preprocessing.image.ImageDataGenerator示例来分别读取训练数据集和测试数据集中的所有图像文件。将训练集图片全部处理为高宽均为224像素的输入。此外,我们对RGB三个颜色通道的数值做标准化。

注意:

class_modelabel 的形状含义
"binary"(32,)二分类(每个样本是 0 或 1)
"categorical"(32, 2)独热编码([1, 0] 或 [0, 1])
"sparse"(32,)多类整数标签(类似 binary)
None无标签仅返回图像,无监督学习时使用

(2)模型构建与训练

实例化预训练数据集(tf.keras.appilcation)------>模型调整(调整输出层,并设置层是否可训练)

  • 我们使用在ImageNet数据集上的预训练模型的ResNet-50作为源模型。这里指定weights = 'imagenet’来自动下载并加载预训练的模型参数。
  • Keras应用程序(keras.applications)是具有预先训练权值的固定框架,该类封装了很多重量级的网络架构

实现时实例化模型架构:

  • 利用tf.keras中的application实现迁移学习
tf.keras.application.ResNet50(include_top = True,  # 是否包含顶层的全连接层(默认为True)weights = 'imagenet', # None代表随机初始化,'imagenet'代表加载在ImageNet上预训练的权重input_tensor = None, # 如果你已经用 tf.keras.Input() 创建了输入层,这里可以传入它;# 一般用于自定义模型结构input_shape = None, # 可选,输入尺寸元组,仅当include_top = False时有效,否则输入形状必须是(224,224,3)(channels_last格式)# 或(3,224,224)(channels_first格式)。它必须为3个输入通道,且高宽必须不小于32pooling = None, # 当 include_top=False 时,是否添加全局池化classes = 1000,**kwargs
)
  • include_top
    • include_top = True, 模型会包含原始 ResNet50 在 ImageNet 上训练的最后三层全连接分类头(avg_poolfc1000 → softmax 输出 1000 类)
    • include_top = False, 就不会包含这些顶层结构,适合迁移学习时接上你自己的分类层。
  • pooling
    • 如果为 None:输出为卷积特征图(feature map),形状类似 (batch, 7, 7, 2048)
    • 'avg':加一层 GlobalAveragePooling2D,输出为 (batch, 2048)
    • 'max':加一层 GlobalMaxPooling2D,输出为 (batch, 2048)
  • classes(输出类别数量)
    • 只有当 **include_top=True** 时有效
    • 用于设置最终全连接层的输出维度。

在该案例中使用resNet50预训练模型架构模型:

# 加载预训练模型
ResNet50 = tf.keras.applications.ResNet50(weights = 'imagenet', input_shape = (224, 224, 3))
# 设置所有层不可训练
for layer in ResNet50.layers:layer.trainable = False# 设置模型
net = tf.keras.models.Squential()
# 预训练模型
net.add(ResNet50)
# 展开
net.add(tf.keras.layers.Flatten())
# 二分类的全连接层
net.add(tf.keras.layers.Dense(2, activation = 'softmax'))

接下来使用之前定义好的ImageGenerator将训练集图片送入ResNet50进行训练

# 模型编译:指定优化器,损失函数,评价指标
net.compile(optimizer = 'adam',loss = 'categorical_crossentropy',metrics = ['accuracy']
)# 模型训练:指定数据,每一个epoch中只运行10个迭代,指定验证数据集
history = net.fit(train_data_gen = True,steps_per_epoch = 10,epochs = 3,validation_data = test_data_gen,  # 验证集validation_step = 10
)
http://www.xdnf.cn/news/17532.html

相关文章:

  • 通过sealos工具在ubuntu 24.02上安装k8s集群
  • DevOps:从GitLab .gitlab-ci.yml 配置文件到CI/CD
  • 第十五讲:set和map
  • WebAssembly技术详解:从浏览器到云原生的高性能革命
  • 本地WSL部署接入 whisper + ollama qwen3:14b 总结字幕增加利用 Whisper 分段信息,全新 Prompt功能
  • 国内外主流大模型深度体验与横向评测:技术、场景与未来展望
  • 生产工具革命:定制开发开源AI智能名片S2B2C商城小程序重构商业生态的范式研究
  • 密码学的数学基础2-Paillier为什么产生密钥对比RSA慢
  • 基于django的宠物用品购物商城的设计与实现
  • Windows安装MySql8.0
  • docker等基础工具使用
  • Linux810 shell 条件判断 文件工具 ifelse
  • 基于多链路智能SD-WAN的船舶智能监控系统安全等级保护实施方案
  • 【工具变量】地市人力资本水平数据集(2003-2023年)
  • 【密码学】7. 数字签名
  • 四、RuoYi-Cloud-Plus 部署时nacos配置服务启动
  • Python 中的 Mixin
  • 区块链密码学简介
  • 第05章 排序与分页
  • 详解Windows(十四)——PowerShell与命令提示符
  • 09 【C++ 初阶】C/C++内存管理
  • 【pyqt5】SP_(Standard Pixmap)的标准图标常量及其对应的图标
  • Vulnhub----Beelzebub靶场
  • 深度学习中基于响应的模型知识蒸馏实现示例
  • Vue 使用element plus组件库提示doesn‘t work properly without JavaScript enabled
  • 【自动化运维神器Ansible】playbook实践示例:HTTPD安装与卸载全流程解析
  • Vue 3.6 Vapor模式完全指南:告别虚拟DOM,性能飞跃式提升
  • [TryHackMe]Challenges---Game Zone游戏区
  • ThingsBoard配置邮件发送保姆级教程(新版qq邮箱)
  • 第二十天:余数相同问题