当前位置: 首页 > news >正文

深度学习——基于卷积神经网络实现食物图像分类之(保存最优模型)

引言

本文将详细介绍如何使用PyTorch框架构建一个完整的食物图像分类系统,包含数据预处理、模型构建、训练优化以及模型保存等关键环节。与上一篇博客介绍的版本相比,本版本增加了模型保存与加载功能,并优化了测试评估流程。

一、项目概述

本项目的目标是构建一个能够识别20种不同食物的图像分类系统。主要技术特点包括:

  1. 简化但高效的数据预处理流程
  2. 三层CNN网络架构设计
  3. 训练过程中自动保存最佳模型
  4. 完整的训练-评估流程实现

二、环境配置

首先确保已安装必要的Python库:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

三、数据预处理

3.1 数据转换设置

我们为训练集和验证集定义了不同的转换策略:

data_transforms = {'train': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),
}

简化说明

3.2 数据集准备
def train_test_file(root, dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots, directories, files in os.walk(path):if len(directories) != 0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()

该函数会生成包含图像路径和标签的文本文件,格式为:

path/to/image1.jpg 0
path/to/image2.jpg 1
...

四、自定义数据集类

我们继承PyTorch的Dataset类实现自定义数据集:

class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

关键改进

五、CNN模型架构

我们设计了一个三层CNN网络:

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)return self.out(x)

架构特点

  1. 每层包含卷积、ReLU激活和最大池化
  2. 使用padding保持特征图尺寸
  3. 最后通过全连接层输出分类结果

六、训练与评估流程

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()if batch_size_num % 1 == 0:print(f"loss: {loss.item():>7f} [batch:{batch_size_num}]")batch_size_num += 1
6.2 评估与模型保存
best_acc = 0def Test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= size# 保存最佳模型if correct > best_acc:best_acc = correcttorch.save(model.state_dict(), "best_model.pth")print(f"\n测试结果: \n 准确率:{(100*correct):.2f}%, 平均损失:{test_loss:.4f}")

关键改进

  1. 增加全局变量best_acc跟踪最佳准确率
  2. 实现两种模型保存方式:(1)只保存模型参数(state_dict)(2)保存整个模型
  3. 更详细的测试结果输出

七、完整训练流程

# 初始化
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练循环
epochs = 10
for t in range(epochs):print(f"Epoch {t+1}\n{'-'*20}")train(train_dataloader, model, loss_fn, optimizer)# 最终评估
Test(test_dataloader, model, loss_fn)

八、模型保存与加载

8.1 保存模型
# 方法1:只保存参数
torch.save(model.state_dict(), "model_params.pth")# 方法2:保存完整模型
torch.save(model, "full_model.pt")
8.2 加载模型
# 方法1对应加载方式
model = CNN().to(device)
model.load_state_dict(torch.load("model_params.pth"))# 方法2对应加载方式
model = torch.load("full_model.pt").to(device)

九、优化建议

  1. 数据增强:添加更多变换提高模型泛化能力
  2. 学习率调度:使用torch.optim.lr_scheduler动态调整学习率

http://www.xdnf.cn/news/1430461.html

相关文章:

  • 前缀和之距离和
  • 架构设计:AIGC 新规下 UGC 平台内容审核防火墙的构建
  • 【XR技术概念科普】什么是注视点渲染(Foveated Rendering)?为什么Vision Pro离不开它?
  • A股大盘数据-20250902分析
  • 深入浅出 RabbitMQ-消息可靠性投递
  • 学习日记-SpringMVC-day48-9.2
  • WPF应用程序资源和样式的使用示例
  • 洗衣店小程序的设计与实现
  • 深度学习篇---DenseNet网络结构
  • gitlab中回退代码,CI / CD 联系运维同事处理
  • VR森林经营模拟体验带动旅游经济发展
  • Time-MOE 音频序列分类任务
  • 【C++框架#2】gflags 和 gtest 安装使用
  • Redis 的跳跃表:像商场多层导航系统一样的有序结构
  • 疯狂星期四文案网第58天运营日记
  • 大模型微调数据准备全指南:清洗、标注与高质量训练集构造实战
  • 科研界“外挂”诞生了:科学多模态模型Intern-S1-mini开源
  • 我的项目我做主:Focalboard+cpolar让团队协作摆脱平台依赖
  • 大数据毕业设计选题推荐-基于大数据的电脑硬件数据分析系统-Hadoop-Spark-数据可视化-BigData
  • 临时邮箱地址获取服务器邮件工作流程与实现
  • playwright+python 实现图片对比
  • 【代码里的英雄传】Dubbo 的一生:一位分布式勇士的传奇旅程
  • 依托深兰科技AI技术生态,深兰教育携手沪上高校企业启动就业科创营
  • 高性能接口实现方案
  • 【微服务】-Gson反序列化泛型类型踩坑指南:如何正确处理Result<T>类型
  • MTK Linux DRM分析(三十)- MTK mtk_dsi.c(Part.2)
  • AI零售创业公司:零眸智能
  • PHP操作LibreOffice将替换变量后的word文件转换为PDF文件
  • ffmpeg 安装
  • C#基础(⑤ProcessStartInfo类和Process类)