当前位置: 首页 > ds >正文

从代码学习深度学习 - 微调 PyTorch 版

文章目录

  • 前言
  • 一、迁移学习与微调概念
  • 二、微调步骤解析
  • 三、实战案例:热狗识别
    • 3.1 数据集准备
    • 3.2 图像增强处理
    • 3.3 加载预训练模型
    • 3.4 模型重构
    • 3.5 差异化学习率训练
    • 3.6 对比实验分析
  • 总结


前言

深度学习模型训练通常需要大量数据,但在实际应用中,我们往往难以获得足够的标记数据。例如,如果我们想构建一个识别不同类型椅子的系统,收集和标记数千甚至数万张椅子图像将耗费大量时间和资金。这种情况下,迁移学习特别是微调(fine-tuning)技术便显示出其强大优势。本文将通过一个热狗识别的实际案例,详细讲解如何在PyTorch中实现微调,帮助读者掌握这一重要技术。注意,本博客只列出了与微调相关的代码,完整代码在下方链接中给出,其中包含了详细的注释。

完整代码:下载链接


一、迁移学习与微调概念

迁移学习是指将从一个任务中学到的知识应用到另一个相关任务中。在计算机视觉领域,我们常常利用在大规模数据集(如ImageNet)上预训练的模型,将其"迁移"到我们的特定任务中。

微调是迁移学习的一种常见方法,它不仅复用预训练模型的架构,还复用其参数,然后通过在目标数据集上继续训练来调整这些参数,使模型适应新任务。这种方法的核心假设是:预训练模型已经学到了通用的特征提取能力,只需要针对新任务做适度调整。

二、微调步骤解析

微调通常包含以下四个关键步骤:

  1. 预训练模型选择:在源数据集(如ImageNet)上训练一个基础模型
  2. 模型结构调整:复制预训练模型的架构和参数(除输出层外)
  3. 输出层替换:添加适合目标任务的新输出层,并随机初始化其参数
  4. 差异化训练:在目标数据集上训练模型,通常对预训练层使用较小学习率,对新添加层使用较大学习率
    在这里插入图片描述

三、实战案例:热狗识别

3.1 数据集准备

首先,我们加载并查看热狗识别的数据集:

# 设置matplotlib在Jupyter Notebook中内嵌显示图表
%matplotlib inline
# 导入必要的库
import os  # 用于处理文件路径
import torch  # PyTorch深度学习框架
import torchvision  # PyTorch视觉库,用于处理图像数据
from torch import nn  # PyTorch神经网络模块
# 导入自定义工具函数,用于显示图像
import utils_for_huitu
# 设置数据目录路径
data_dir = 'hotdog'  # 数据根目录
# 加载训练集图像
# ImageFolder假设数据按类别存放在不同文件夹中
# 文件结构应为:hotdog/train/[类别1]/, hotdog/train/[类别2]/ 等
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
# 加载测试集图像
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
# 获取训练集中的热狗图像样本
# 从训练集的前8张图像中获取图像数据
# train_imgs[i][0]表示第i个样本的图像数据,train_imgs[i][1]是对应的标签
hotdogs = [train_imgs[i][0] for i in range(8)]
# 获取训练集中的非热狗图像样本
# 从训练集的末尾8张图像中获取图像数据
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
# 显示图像
# 将热狗和非热狗图像合并为一个列表,并显示在2行8列的网格中
# scale参数用于调整图像显示的大小
utils_for_huitu.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)

http://www.xdnf.cn/news/1924.html

相关文章:

  • 数据结构初阶:排序
  • 【MySQL专栏】MySQL数据库的复合查询语句
  • Pycharm(六):可变与不可变类型
  • 【时时三省】(C语言基础)循环程序举例
  • 手把手搭建AIGC应用:从图像生成到智能写作全实战
  • 使用 n8n 实现你的第一个爬虫程序:从零到自动化
  • 【金仓数据库征文】- 金融HTAP实战:KingbaseES实时风控与毫秒级分析一体化架构
  • 飞牛 NAS 整机要来了?!
  • C#高级语法--接口
  • 初识HashMap
  • 华为L410上制作内网镜像模板:在客户端配置模板内容
  • 施工配电箱巡检二维码应用
  • 【EDA】EDA中聚类(Clustering)和划分(Partitioning)
  • STM32F103C8T6信息
  • 【金仓数据库征文】-不懂数据库也能看懂!一文解析金仓技术介绍以典型应用
  • 力扣-206.反转链表
  • 2025最新版扣子(Coze)AI智能体应用指南
  • 118. 杨辉三角
  • c++——内部类
  • AI 开发入门之 RAG 技术
  • 解析Mqtt 消息服务质量Qos
  • 2025最新软件测试面试八股文(答案+文档+视频讲解)
  • linux 桌面环境
  • 如何用大模型技术重塑物流供应链
  • 【C++基础知识】C++类型特征组合:`disjunction_v` 和 `conjunction_v` 深度解析
  • linux centOS7.9 No package docker-ce available
  • 解决 Windows10 下 UWP 应用无法使用本地代理
  • Python实现技能记录系统
  • 建筑安全员考试科目有哪些
  • 从梯度消失到百层网络:ResNet 是如何改变深度学习成为经典的?