当前位置: 首页 > java >正文

如何保存训练的最优模型和使用最优模型文件

一 保存最优模型

主要就是我们在for循环中加上一个test测试,并且我还在test函数后面加上了返回值,可以返回准确率,然后每次进行一次对比,然后取大的。然后这里有两种保存方式,一种是保存了整个模型,另一个是保存了模型参数。

1 仅保存模型参数

torch.save(model.state_dict(),'best_model.pth')

然后后面我们使用的时候

model =torch.load('best1.pth')#
model.to(device)
model.load_state_dict(torch.load("best.pth", map_location=device))
model.eval()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
test(test_dataloader,model,loss_fn)

注意这里要设置eval模式,因为我们要保证我们的模型参数不再变化了。

2 保存整个模型

torch.save(model,'best1.pth')

在调用的时候

model = torch.load('best1.pth', map_location=torch.device('cuda'))
model.eval()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
test(test_dataloader,model,loss_fn)

直接调用就好。

注意这两种必须要有定义好的网络,不然无法运行(保存整个网络也要定于一个完全相同的网络)。

完整代码

epochs=20
for i in range(epochs):print(f"Epoch {i+1}")train(train_dataloader,model,loss_fn,optimizer)corrects = test(test_dataloader,model,loss_fn)accuracy_list.append(corrects)if corrects>best_acc:print(f"Best Accuracy: {corrects}%")best_acc=corrects#第一种# torch.save(model.state_dict(),'best_model.pth')#第二种torch.save(model,'best1.pth')

完整代码含网络

import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transformsclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(  # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=3,  # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数),out_channels=16,  # 要得到多少个特征图,卷积核的个数kernel_size=5,  # 卷积核人小,5*5stride=1,  # 步长padding=2  # 填充值),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 进行池化操作(2x2 区域))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),)self.out = nn.Linear(64 * 64 * 64, 20)  # 全连接层得到的结果def forward(self, x):  # 前向传播,你得告诉它 数据的流向 是神经网络层连接起来,函数名称不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # view和reshape是一样的作用,但此处是tensor形式output = self.out(x)return outputdata_transform={# 'train': transforms.Compose([#     # 调整图像大小为300x300像素#     transforms.Resize([256, 256]),##     # # 随机旋转:-45到45度之间随机选择角度#     # transforms.RandomRotation(45),#     # ##     # # # 从中心裁剪出256x256的区域#     # transforms.CenterCrop([256, 256]),#     ##     # # 随机水平翻转:以50%的概率进行水平镜像#     # transforms.RandomHorizontalFlip(p=0.5),#     ##     # # 随机垂直翻转:以50%的概率进行垂直镜像#     # transforms.RandomVerticalFlip(p=0.5),#     ##     # # # 颜色抖动:随机调整亮度、对比度、饱和度和色调#     # # transforms.ColorJitter(#     # #     brightness=0.2,    # 亮度变化幅度为20%#     # #     contrast=0.1,      # 对比度变化幅度为10%#     # #     saturation=0.1,    # 饱和度变化幅度为10%#     # #     hue=0.1            # 色调变化幅度为10%#     # # ),#     # ##     # # # 随机灰度化:以10%的概率将图像转换为灰度图#     # transforms.RandomGrayscale(p=0.1),##     # 将PIL图像转换为PyTorch张量,并自动归一化到[0,1]范围#     transforms.ToTensor(),##     # 标准化:使用ImageNet数据集的均值和标准差进行标准化#     transforms.Normalize(#         [0.485, 0.456, 0.406],  # 均值(R, G, B通道)#         [0.229, 0.224, 0.225]   # 标准差(R, G, B通道)#     )# ]),# 验证/测试数据的预处理(通常不需要数据增强)'test': transforms.Compose([transforms.Resize([256, 256]),# transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}class food_dataset(Dataset):def __init__(self, root, transform=None):super().__init__()self.root = rootself.transform = transformself.images = []self.labels = []with open(root,encoding='utf-8') as f:samples = [i.strip().split() for i in f.readlines()]for img_path,label in samples:self.images.append(img_path)self.labels.append(label)def __len__(self):return len(self.images)def __getitem__(self, index):image=Image.open(self.images[index]).convert('RGB')if self.transform:image=self.transform(image)label = self.labels[index]# print(label)label = torch.from_numpy(np.array(label,dtype=np.int64))# print(label)return image, labeldef test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()batch_size_num=1loss,correct=0,0with torch.no_grad():for X, y in test_dataloader:X,y=X.to(device),y.to(device)pred = model(X)loss = loss_fn(pred,y)+losscorrect += (pred.argmax(1) == y).type(torch.float).sum().item()loss/=num_batchescorrect/=sizeprint(f'Test result: \n Accuracy: {(100*correct)}%,Avg loss: {loss}')device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")test_data=food_dataset('test_data',transform=(data_transform['test']))
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)# model =CNN()
# model.to(device)
# model.load_state_dict(torch.load("best.pth"))
model=torch.load('best.pt')
model.eval()
loss_fn = nn.CrossEntropyLoss()
test(test_dataloader,model,loss_fn)

http://www.xdnf.cn/news/19900.html

相关文章:

  • 【wpf】WPF开发避坑指南:单例模式中依赖注入导致XAML设计器崩溃的解决方案
  • SpringBoot注解生效原理分析
  • AI落地新趋势:美林数据揭示大模型与小模型的协同进化论
  • Java中 String、StringBuilder 和 StringBuffer 的区别?
  • 小皮80端口被NT内核系统占用解决办法
  • 期货反向跟单—从小白到高手的进阶历程 七(翻倍跟单问题)
  • 【Java】对于XML文档读取和增删改查操作与JDBC编程的读取和增删改查操作的有感而发
  • 加解密安全-侧信道攻击
  • Python分布式任务队列:万级节点集群的弹性调度实践
  • Unity 枪械红点瞄准器计算
  • linux内核 - 服务进程是内核的主要责任
  • dockerfile文件的用途
  • 机器能否真正语言?人工智能NLP面临的“理解鸿沟与突破
  • 键盘上面有F3,四,R,F,V,按下没有反应,维修记录
  • Echo- Go Web Framework的介绍
  • MCP over SSE 通信过程详解:双通道架构下的高效对话
  • 关于牙科、挂号、医生类小程序或管理系统项目 项目包含微信小程序和pc端两部分
  • 《计算机网络安全》实验报告一 现代网络安全挑战 拒绝服务与分布式拒绝服务攻击的演变与防御策略(1)
  • createrepo生成yum仓库元数据xml文件
  • 【机器学习学习笔记】逻辑回归实现与应用
  • 微信小程序预览和分享文件
  • AI生成内容的版权迷局:GPT-4输出的“创意”版权风险与规避之道
  • 解决服务器 DNS 解析失败,从这几步排查开始
  • MiniCPM-V 4.5 模型解析
  • 代码随想录算法训练营第二天| 209.长度最小的子数组
  • 变频器实习DAY42 VF与IF电机启动方式
  • 开源网络流量分析利器:tproxy
  • 嵌入式 - 硬件:51单片机(2)
  • daily notes[9]
  • 校园外卖点餐系统(代码+数据库+LW)