当前位置: 首页 > ds >正文

残差网络 迁移学习对食物分类案例的改进

目录

一.直接修改最后一层

1.导入模型

2.冻结模型所有参数

3.修改最后一层

4.获取需要更新的参数

5.创建模型

6.优化器和调整学习率

7.完整代码

二.增加一层

1.模型获取和冻结参数

2.定义新的网络类

3.创建模型

4.获取需要更新的参数

5.优化器和调整学习率

6.完整代码


本篇我们直接使用ImageNet竞赛第一名何凯明训练好的经典的ResNet模型来改进我们的食物分类案例

而何凯明训练的18层残差网络模型输出是用来分类1000种物体而不是我们食物分类案例的20所以我们是需要再最后的输出层进行调整,这里我们有两种方法:①直接修改最后一层网络的输出②再添加一层网络

一.直接修改最后一层

1.导入模型

torchvision库中有大量的优秀的视觉方面的已经训练好的模型,torch.nn则倾向于自己搭建网络

import torchvision.models as models

我们直接导入何凯明的resnet18模型,注意这里的模型是需要网络下载的

resnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

2.冻结模型所有参数

然后我们需要冻结模型的所有参数,避免我们后续训练改变它们

require_grad属性设置为False即可

for param in resnet_model.parameters():#冻结模型所有参数param.requires_grad=False

3.修改最后一层

fc.in_features属性可以获得最后一层全连接层的输入个数

然后我们直接把这个fc层赋值成一个新的全连接层并把输出设置为20

in_feature=resnet_model.fc.in_features
resnet_model.fc=nn.Linear(in_feature,20)

4.获取需要更新的参数

遍历模型所有的参数,由于我们之前已经冻结了所有参数,所以只有我们上一步重新赋值的fc层是可以修改参数的,将fc的所有权重参数放入一个列表中,方便我们后续优化器进行反向传播优化

params_to_update=[]
for param in resnet_model.parameters():if param.requires_grad==True:params_to_update.append(param)

5.创建模型

由于我们这里是迁移学习,所以我们不用向像之前再创建类的的对象,我们已经有了模型的对象

model=resnet_model.to(device)

6.优化器和调整学习率

优化器我们要注意传入的参数是我们需要更新的最后一层的参数

调整学习率我们采用有序的等间隔调整

optimizer=torch.optim.Adam(params_to_update,lr=0.001)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,5,0.5)

其余代码并未做任何的修改完整代码如下:

7.完整代码

import torch
from torch import  nn
from torch.utils.data import Dataset,DataLoader
import numpy as np
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import os
dire={}
def train_test_file(root,dir):f_out=open(dir+'.txt','w')path=os.path.join(root,dir)for root,directories,files in os.walk(path):if len(directories)!=0:dirs=directorieselse:now_dir=root.split('\\')for file in files:path=os.path.join(root,file)f_out.write(path+' '+str(dirs.index(now_dir[-1]))+'\n')dire[dirs.index(now_dir[-1])]=now_dir[-1]f_out.close()
root=r'.\food_dataset2'
train_dir='train'
test_dir='test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)resnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():#冻结模型所有参数param.requires_grad=Falsein_feature=resnet_model.fc.in_features
resnet_model.fc=nn.Linear(in_feature,20)params_to_update=[]
for param in resnet_model.parameters():if param.requires_grad==True:params_to_update.append(param)data_transforms={'train':transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),#随机旋转transforms.CenterCrop(224),#保持与迁移的模型一致transforms.RandomHorizontalFlip(p=0.5),#随机水平旋转,概率p=0.5transforms.RandomVerticalFlip(p=0.5),#随机垂直旋转,概率p=0.5transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),#brightness(亮度)contrast(对比度)saturation(饱和度)hue(色调)transforms.RandomGrayscale(p=0.1),#p的概率转换为灰度图,但任然是三个通道,不过三个通道相同transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#标准化,设置通用的均值和标准差]),'valid':#测试集只用标准化transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
}
class food_dataset(Dataset):#能通过索引的方式返回图片数据和标签结果def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs_paths=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs_paths.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs_paths)def __getitem__(self, idx):image=Image.open(self.imgs_paths[idx])if self.transform:image=self.transform(image)label=self.labels[idx]label=torch.from_numpy(np.array(label,dtype=np.int64))#label也转化为tensorreturn image,labeltrain_data=food_dataset(file_path='./train.txt',transform=data_transforms['train'])
test_data=food_dataset(file_path='./test.txt',transform=data_transforms['valid'])train_loader=DataLoader(train_data,batch_size=32,shuffle=True)
test_loader=DataLoader(test_data,batch_size=32,shuffle=True)
device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')model=resnet_model.to(device)def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value=loss.item()if batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1
best_acc=0
def test(dataloader,model,loss_fn):global best_accmodel.eval()len_data=len(dataloader.dataset)correct,loss_sum=0,0num_batch=0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss_sum += loss_fn(pred, y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()num_batch+=1loss_avg=loss_sum/num_batchaccuracy=correct/len_dataif accuracy>best_acc:best_acc=accuracyprint(f'Accuracy:{100 * accuracy}%\nLoss Avg:{loss_avg}')return pred.argmax(1)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(params_to_update,lr=0.001)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,5,0.5)epochs=10
for i in range(epochs):print(f'==========第{i + 1}轮训练==============')train(train_loader, model, loss_fn, optimizer)test(test_loader, model, loss_fn)scheduler.step()print(f'第{i + 1}轮训练结束')
print('best_acc=',best_acc)

二.增加一层


我们如何在最后添加一层是通过定义一个类来实现的原理如下:

我们是通过类将迁移学习的模块和新的fc层组合在一起实现在残差网络中添加一层网络

1.模型获取和冻结参数

resnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():#冻结模型所有参数param.requires_grad=False

2.定义新的网络类

class ResNet(nn.Module):def __init__(self):super().__init__()self.resnet_=resnet_modelself.fc=nn.Linear(1000,20)def forward(self,x):out=self.resnet_(x)out=self.fc(out)return out

3.创建模型

这里创建模型我们又用到了类的创建

model=ResNet().to(device)

4.获取需要更新的参数

params_to_update=[]
for param in model.parameters():if param.requires_grad==True:params_to_update.append(param)

5.优化器和调整学习率

这里的调整学习率我们采用自适应调整

optimizer=torch.optim.Adam(params_to_update,lr=0.001)
# scheduler=torch.optim.lr_scheduler.StepLR(optimizer,5,0.5)
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,mode='min',factor=0.5,patience=2,verbose=False,threshold=0.0001,threshold_mode='rel',cooldown=0,min_lr=0,eps=1e-08)

6.完整代码

其余代码并无修改,只是train()方法做了一个返回平均损失值的改进用来作为学习率调整的根据指标,完整代码如下:

import torch
from torch import  nn
from torch.utils.data import Dataset,DataLoader
import numpy as np
from PIL import Image
from torchvision import transforms
import torchvision.models as modelsresnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():#冻结模型所有参数param.requires_grad=False
class ResNet(nn.Module):def __init__(self):super().__init__()self.resnet_=resnet_modelself.fc=nn.Linear(1000,20)def forward(self,x):out=self.resnet_(x)out=self.fc(out)return outdata_transforms={'train':transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),#随机旋转transforms.CenterCrop(224),#保持与迁移的模型一致transforms.RandomHorizontalFlip(p=0.5),#随机水平旋转,概率p=0.5transforms.RandomVerticalFlip(p=0.5),#随机垂直旋转,概率p=0.5transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),#brightness(亮度)contrast(对比度)saturation(饱和度)hue(色调)transforms.RandomGrayscale(p=0.1),#p的概率转换为灰度图,但任然是三个通道,不过三个通道相同transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#标准化,设置通用的均值和标准差]),'valid':#测试集只用标准化transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
}
class food_dataset(Dataset):#能通过索引的方式返回图片数据和标签结果def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs_paths=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs_paths.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs_paths)def __getitem__(self, idx):image=Image.open(self.imgs_paths[idx])if self.transform:image=self.transform(image)label=self.labels[idx]label=torch.from_numpy(np.array(label,dtype=np.int64))#label也转化为tensorreturn image,labeltrain_data=food_dataset(file_path='./train.txt',transform=data_transforms['train'])
test_data=food_dataset(file_path='./test.txt',transform=data_transforms['valid'])train_loader=DataLoader(train_data,batch_size=32,shuffle=True)
test_loader=DataLoader(test_data,batch_size=32,shuffle=True)
device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')model=ResNet().to(device)
params_to_update=[]
for param in model.parameters():if param.requires_grad==True:params_to_update.append(param)def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1loss_sum=0total_batches = len(dataloader)  # 获取总批次数量for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value=loss.item()loss_sum+=loss_valueif batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1return loss_sum/total_batches
best_acc=0
def test(dataloader,model,loss_fn):global best_accmodel.eval()len_data=len(dataloader.dataset)correct,loss_sum=0,0num_batch=0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss_sum += loss_fn(pred, y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()num_batch+=1loss_avg=loss_sum/num_batchaccuracy=correct/len_dataif accuracy>best_acc:best_acc=accuracyprint(f'Accuracy:{100 * accuracy}%\nLoss Avg:{loss_avg}')return pred.argmax(1)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(params_to_update,lr=0.001)
# scheduler=torch.optim.lr_scheduler.StepLR(optimizer,5,0.5)
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,mode='min',factor=0.5,patience=2,verbose=False,threshold=0.0001,threshold_mode='rel',cooldown=0,min_lr=0,eps=1e-08)epochs=10
for i in range(epochs):print(f'==========第{i + 1}轮训练==============')loss=train(train_loader, model, loss_fn, optimizer)scheduler.step(loss)test(test_loader, model, loss_fn)print(f'第{i + 1}轮训练结束')
print('best_acc=',best_acc)

http://www.xdnf.cn/news/20464.html

相关文章:

  • STL模版在vs2019和gcc中的特殊问题
  • STM32项目分享:基于物联网的健康监测系统设计
  • 基于STM32的智能宠物屋系统设计
  • 人工智能学习:什么是seq2seq模型
  • Java全栈开发工程师的面试实战:从基础到复杂场景的技术探索
  • Compose笔记(四十九)--SwipeToDismiss
  • RabbitMQ工作模式(下)
  • 贪心算法应用:蛋白质折叠问题详解
  • Eureka与Nacos的区别-服务注册+配置管理
  • AI模型测评平台工程化实战十二讲(第一讲:从手工测试到系统化的觉醒)
  • 力扣29. 两数相除题解
  • Qt资源系统学习
  • 【继承和派生】
  • 【Flask】测试平台开发,重构提测管理页面-第二十篇
  • 把装配想象成移动物体的问题
  • java基础学习(五):对象中的封装、继承和多态
  • C++经典的数据结构与算法之经典算法思想:排序算法
  • phpMyAdmin文件包含漏洞复现:原理详解+环境搭建+渗透实战(windows CVE-2014-8959)
  • 吴恩达机器学习(七)
  • 综合安防集成系统解决方案,智慧园区,智慧小区安防方案(300页Word方案)
  • 《2025国赛/高教杯》C题 完整实战教程(代码+公式详解)
  • 关于连接池
  • 【PostgreSQL】如何实现主从复制?
  • 网络原理-
  • 在Ubuntu平台搭建RTMP直播服务器使用SRS简要指南
  • Qt 基础教程合集(完)
  • 分布式数据架构
  • 硬件开发_基于物联网的老人跌倒监测报警系统
  • 数据结构——栈(Java)
  • MySQL数据库约束和设计