当前位置: 首页 > news >正文

迁移学习实战:基于 ResNet18 的食物分类

一、迁移学习简介

迁移学习是一种高效的机器学习方法,它利用在大规模数据集上预训练好的模型,在新的任务上进行微调。这样做的优势十分显著:

  • 加速训练:无需从零开始训练模型,节省大量时间。
  • 提升性能:预训练模型已经学习到了通用的特征表示,能为新任务提供良好的基础。
  • 数据高效:在新任务数据稀缺时,也能取得不错的效果。

二、迁移学习步骤

1. 选择预训练模型和适当的层

通常会选择在大规模图像数据集(如 ImageNet)上预训练的模型,像 VGG、ResNet 等。对于不同的任务,选择的层也有所不同:

  • 若任务是低级特征提取(如边缘检测),适合使用浅层模型的层。
  • 若任务是高级特征相关(如分类),则应选择更深层次的模型。

2. 冻结预训练模型的参数

保持预训练模型的权重不变,只训练新增加的层或者微调部分层。这样做是为了避免预训练模型在新数据集上过度拟合,同时也能减少计算量。

3. 在新数据集上训练新增加的层

在冻结预训练模型参数的情况下,训练新增加的层,使新模型能够适应新的任务,从而提升性能。

4. 微调预训练模型的层

在新层训练完成后,解冻一些已经训练过的层并进行微调,进一步提高模型在新数据集上的性能。

5. 评估和测试

训练完成后,使用测试集对模型进行评估。若模型性能不佳,可调整超参数或更改微调层。

三、基于 ResNet18 的食物分类实战

   使用上节课所说的残差网络的18层结构来对其进行微调,该残差网络结构如下图所示:

此时我们可以发现输入图像的特征大小为3*224*224,输出特征图格式为512*1*1,然后将其进行全连接层处理后变成输入512张特征图,输出1000个预测结果,这个结果的种类太多,我们不需要使用这么多的预测类别,所以当下需要对其微调,调整最后输出时的全连接层输出结果个数及其全连接层中的权重参数。

1. 导入预训练模型

我们选择在 ImageNet 上预训练好的 ResNet18 模型,代码如下:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np# 导入预训练的ResNet18模型
resent_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

2. 冻结预训练模型参数

通过设置参数的requires_grad属性为False,冻结预训练模型的参数,使其在训练过程中不参与梯度更新:

for param in resent_model.parameters():param.requires_grad = False  # 冻结所有预训练模型参数

3. 修改全连接层

原 ResNet18 模型是为 ImageNet 的 1000 类分类任务设计的,我们要将其适配为 20 类食物分类任务,所以需要修改全连接层,并收集需要训练的参数:

in_features = resent_model.fc.in_features  # 获取原全连接层的输入特征数
resent_model.fc = nn.Linear(in_features, 20)  # 替换为输出为20类的全连接层param_to_update = []  # 收集需要训练的参数(仅新的全连接层)
for param in resent_model.parameters():if param.requires_grad:param_to_update.append(param)

4. 自定义数据集类与数据增强

创建food_dataset类来加载食物图像数据,并通过数据增强来提升模型的泛化能力:

class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label# 数据增强与预处理
data_transforms = {'train':transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),transforms.CenterCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'test':transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}# 加载训练集和测试集
train_data = food_dataset(file_path=r'train.1txt', transform=data_transforms['train'])
test_data = food_dataset(file_path=r'test.1txt', transform=data_transforms['test'])# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

train.1txt,test.1txt如下:

5. 定义训练和测试函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model.forward(x)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 40 == 0:print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1best_acc = 0
acc_s = []
loss_s = []def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model.forward(x)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}\n")acc_s.append(correct)loss_s.append(test_loss)if correct > best_acc:best_acc = correct

6. 模型设备部署与优化器设置

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
model = resent_model.to(device)loss_fn = nn.CrossEntropyLoss()  # 多分类损失函数
optimizer = torch.optim.Adam(param_to_update, lr=0.001)  # 仅优化新全连接层参数
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # 学习率调度器

7. 训练与测试

epochs = 10
for t in range(epochs):print(f"Epoch {t + 1}\n--------------------------")train(train_dataloader, model, loss_fn, optimizer)scheduler.step()test(test_dataloader, model, loss_fn)
print('最优测试结果为:', best_acc)

训练结果如下:

http://www.xdnf.cn/news/1472599.html

相关文章:

  • BYOFF (Bring Your Own Formatting Function)解析(80)
  • GPU集群扩展:Ray Serve与Celery的技术选型与应用场景分析
  • Pinia 两种写法全解析:Options Store vs Setup Store(含实践与场景对比)
  • (3)Seata AT 模式的事务一致性保证机制
  • MySQL慢查询优化策略
  • 洛谷 P2392 kkksc03考前临时抱佛脚-普及-
  • 【C++题解】贪心和模拟
  • Linux设备down机,如何识别是 断电还是软件复位
  • Java笔记20240726
  • 【Day 22】94.二叉树的中序遍历 104.二叉树的最大深度 226.翻转二叉树 101.对称二叉树
  • linux上nexus安装教程
  • 从“下山”到AI引擎:全面理解梯度下降(下)
  • 学习心得分享
  • 【OJ】C++ vector类OJ题
  • 使用国内镜像源解决 Electron 安装卡在 postinstall 的问题
  • 【Python - 类库 - BeautifulSoup】(01)“BeautifulSoup“使用示例
  • ESP-idf注册双服务器配置
  • SemiSAM+:在基础模型时代重新思考半监督医学图像分割|文献速递-深度学习人工智能医疗图像
  • 笔记:现代操作系统:原理与实现(2)
  • CLIP学习
  • 【C++】Vector完全指南:动态数组高效使用
  • Transformer核心—自注意力机制
  • 大批项目经理被迫上前线,酸爽
  • 图片在vue2中引用的方式和优缺点
  • 【数字孪生核心技术】什么是倾斜摄影?
  • 遇到 Git 提示大文件无法上传确实让人头疼
  • SVT-AV1编码器中实现WPP依赖管理核心调度
  • 门控MLP(Qwen3MLP)与稀疏混合专家(Qwen3MoeSparseMoeBlock)模块解析
  • 【开题答辩全过程】以 基于JSP的宠物医院管理系统设计为例,包含答辩的问题和答案
  • LTV-1008-TP1-G 电子元器件 LiteOn光宝 发光二极管 核心解析