当前位置: 首页 > ds >正文

利用迁移学习实现食物分类:基于PyTorch与ResNet18的实战案例

利用迁移学习实现食物分类:基于PyTorch与ResNet18的实战案例

在深度学习领域,训练一个高性能的模型往往需要大量的数据和计算资源。然而,通过迁移学习,我们能够巧妙地利用在大规模数据集上预训练好的模型,将其知识迁移到我们特定的任务中,不仅可以大幅减少训练时间和数据需求,还能取得出色的效果。本文将以食物分类为例,详细介绍如何使用PyTorch和ResNet18进行迁移学习。

一、迁移学习概述

迁移学习的核心思想是将在一个任务(源任务)中学习到的知识,应用到另一个相关任务(目标任务)中。在计算机视觉领域,许多预训练模型,如ResNet、VGG等,已经在大规模图像数据集(如ImageNet)上进行了充分训练,学习到了丰富的图像特征表示。这些预训练模型的底层网络结构能够提取通用的图像特征,如边缘、纹理等,而顶层网络结构则与源任务的类别紧密相关。因此,在目标任务中,我们可以保留预训练模型的底层结构,仅对顶层进行微调,使其适应目标任务的分类需求。

二、食物分类项目实现

1. 环境与库导入

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch import nn
import torchvision.models as models
from PIL import Image
import numpy as np

上述代码导入了项目所需的核心库。torch是PyTorch的核心库,用于构建和训练深度学习模型;DataLoaderDataset用于数据的加载和管理;transforms用于对图像进行预处理;nn是PyTorch的神经网络模块;models包含了各种预训练模型;Image用于处理图像;numpy用于数值计算。

2. 加载预训练模型并调整结构

resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():param.requires_grad = False
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)
params_to_update = []
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)

首先,通过models.resnet18(weights=models.ResNet18_Weights.DEFAULT)加载在ImageNet数据集上预训练好的ResNet18模型。然后,将模型的所有参数的requires_grad属性设置为False,冻结模型的参数,避免在训练过程中对其进行更新。接着,获取原模型全连接层的输入特征个数in_features,并将原全连接层替换为一个新的全连接层,输出维度为20,对应食物分类任务的20个类别。最后,筛选出需要更新的参数,即新添加的全连接层的参数。

3. 数据准备与预处理

food_type = {0: "八宝粥", 1: "巴旦木", 2: "白萝卜", 3: "板栗", 4: "菠萝", 5: "草莓", 6: "蛋", 7: "蛋挞", 8: "骨肉相连",9: "瓜子", 10: "哈密瓜", 11: "汉堡", 12: "胡萝卜", 13: "火龙果", 14: "鸡翅", 15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯条", 19: "炸鸡"}
data_transforms = {'train':transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),transforms.CenterCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid':transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, labeltraining_data = food_dataset(file_path='trainda.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='testda.txt', transform=data_transforms['valid'])train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

定义了食物类别字典food_type,以及训练集和验证集的图像预处理操作。训练集的预处理包括调整图像大小、随机旋转、中心裁剪、随机水平和垂直翻转、随机灰度化、转换为张量以及标准化;验证集的预处理相对简单,仅进行调整大小、转换为张量和标准化。

创建自定义的数据集类food_dataset,继承自Dataset类,实现了__init____len____getitem__方法,用于读取数据文件、获取数据集大小以及加载和预处理图像。最后,使用DataLoader将训练集和测试集封装为可迭代的数据加载器,方便在训练和测试过程中按批次获取数据。

4. 模型训练与测试

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model = resnet_model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)def train(dataloader, model, loss_fn, optimizer):model.train()for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()best_acc = 0
def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeresult = zip(pred.argmax(1).tolist(), y.tolist())for i in result:print(f"当前测试的结果为:{food_type[i[0]]},当前真实的结果为:{food_type[i[1]]}")print(f"Test result:\n Accurracy:{(100 * correct)}%,AVG loss:{test_loss}")test_loss /= num_batchescorrect /= sizeif correct > best_acc:best_acc = correctepoch = 10
acc_s = []
loss_s = []
for i in range(epoch):print(i + 1)train(train_dataloader, model, loss_fn, optimizer)scheduler.step()test(test_dataloader, model, loss_fn)
print('最终训练结果:', best_acc)

首先,根据当前设备是否支持GPU或苹果M系列芯片的GPU,选择合适的计算设备,并将模型移动到该设备上。定义交叉熵损失函数loss_fn、Adam优化器optimizer以及学习率调整策略scheduler

train函数用于模型的训练,在训练过程中,将数据传入设备,进行前向传播计算预测值,计算损失,通过反向传播计算梯度并更新模型参数。test函数用于模型的测试,在测试过程中,将模型设置为评估模式,关闭梯度计算,计算测试集上的损失和准确率,并输出每个样本的预测结果和真实结果。

最后,通过循环进行多个epoch的训练和测试,在每个epoch结束后调整学习率,并记录最佳准确率。

三、总结

通过本次食物分类项目,我们成功地运用迁移学习技术,基于预训练的ResNet18模型完成了特定任务。这种方法不仅减少了训练时间和数据需求,还展示了迁移学习在实际应用中的强大能力。在未来的深度学习项目中,迁移学习将继续发挥重要作用,帮助我们更高效地解决各种复杂的问题。同时,我们还可以进一步探索不同的预训练模型、调整超参数以及优化数据预处理方法,以提升模型的性能。

http://www.xdnf.cn/news/4071.html

相关文章:

  • 【蓝牙协议栈】【BR/EDR】【AVCTP】精讲音视频控制传输协议
  • 分享一个Android中文汉字手写输入法并带有形近字联想功能
  • Baklib驱动企业知识管理AI升级
  • day15 python 复习日
  • 复杂网络系列:第 5 部分 — 社区检测和子图
  • 在写setup时遇到的问题与思考
  • Circular Plot系列(一): 环形热图绘制
  • 《马小帅的Java闯关记》
  • 模型部署与提供服务
  • QpushButton 扩展InteractiveButtonBase
  • k230摄像头初始化配置函数解析
  • nproc命令查看可用核心数量详解
  • [Windows] 智绘教 v20250403a 屏幕批注工具
  • day 12 三种启发式算法:遗传算法、粒子群算法、退火算法
  • 用卷积神经网络 (CNN) 实现 MNIST 手写数字识别
  • Python函数完全指南:从零基础到灵活运用
  • 深度学习中保存最优模型的实践与探索:以食物图像分类为例
  • GTID(全局事务标识符)的深入解析
  • 高翔《视觉SLAM十四讲》中第13讲,单目稠密重建中的RMODE数据集
  • TS 元组
  • 2025年PMP 学习三
  • 游戏开发的TypeScript(4)TypeScript 的一些内置函数
  • TF-IDF算法详解
  • C# 定时器实现
  • 正态分布习题集 · 题目篇
  • 递归算法详解(Java 实现):从原理到高阶应用
  • 类和对象(上)
  • C语言 指针(5)
  • 两台电动缸同步算法
  • n8n 构建一个 ReAct AI Agent 示例