当前位置: 首页 > backend >正文

食品分类案例

项目任务:导入食品图片,通过模型识别出食品种类。

代码实现:

import os
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch import nn
def train_test(root,wj):file_txt =open(wj+'.txt','w',encoding='utf-8')path =os.path.join(root,wj)for roots,wjj,files in os.walk(root):if len(wjj) != 0:wjs = wjjelse:new_wj= roots.split('\\')for file in files:path_1 = os.path.join(roots, file)print(path_1)file_txt.write(path_1 + ' ' + str(wjs.index(new_wj[-1])) + '\n')file_txt.close()
root1 = r'E:\pythonProject3\深度学习\卷积神经网络\食物分类\food_dataset2\food_dataset2\train'
root2 = r'E:\pythonProject3\深度学习\卷积神经网络\食物分类\food_dataset2\food_dataset2\test'
train_wj ='train1'
test_wj = 'test1'
train_test(root1,train_wj)
train_test(root2,test_wj)
data_transforms = {'train': transforms.Compose([transforms.Resize([256, 256]),transforms.RandomRotation(45),transforms.CenterCrop(256),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]),'valid': transforms.Compose([transforms.Resize([256, 256]),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
}# 3. 修复:自定义数据集类(标签预处理+通道统一+范围校验)
class food_dataset(Dataset):def __init__(self, file_path, transform=None, num_classes=20):  # 统一10类(如需20类,后续同步改)self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformself.num_classes = num_classes  # 类别数与模型保持一致with open(self.file_path, encoding='utf-8') as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label_str in samples:# 关键1:标签从字符串转整数(原代码未转,导致后续类型混乱)label = int(label_str)# 关键2:标签范围校验(确保在0~num_classes-1,避免超出模型输出范围)if label < 0 or label >= self.num_classes:print(f"警告:原始标签{label}超出{self.num_classes}类范围(0~{self.num_classes - 1}),已强制设为0")label = 0self.imgs.append(img_path)self.labels.append(label)  # 存储整数标签# 打印加载结果,验证标签范围是否正确print(f"加载{file_path}完成:")print(f"  标签范围:{min(self.labels)} ~ {max(self.labels)}")print(f"  模型预期范围:0 ~ {self.num_classes - 1}")print(f"  总图片数:{len(self.imgs)}\n")def __len__(self):return len(self.imgs)def __getitem__(self, idx):# 关键3:强制图片为RGB通道(避免单通道/四通道图,与卷积层输入3通道匹配)image = Image.open(self.imgs[idx]).convert('RGB')if self.transform:image = self.transform(image)# 标签转torch.long(CrossEntropyLoss强制要求)label = self.labels[idx]label_tensor = torch.from_numpy(np.array(label, dtype=np.int64))return image, label_tensor# 4. 修复:统一类别数(模型、数据集实例化保持一致)
num_classes = 20  # 如需改为20类,需同步修改3处:1.数据集num_classes 2.模型num_classes 3.标签范围(0~19)
training_data = food_dataset(file_path=r'E:\pythonProject3\深度学习\卷积神经网络\食物分类\train.txt',transform=data_transforms['train'],num_classes=num_classes  # 与模型类别数一致
)
test_data = food_dataset(file_path=r'E:\pythonProject3\深度学习\卷积神经网络\食物分类\test.txt',transform=data_transforms['valid'],num_classes=num_classes  # 与模型类别数一致
)# 5. 数据加载器(batch_size=64,如需降低显存可改为32)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)# 6. 设备配置(不变)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using {device} device\n")# 7. 修复:模型类别数与数据集统一
class CNN(nn.Module):def __init__(self, num_classes=20):  # 与数据集num_classes一致super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),  # 3通道匹配RGB图nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 256→128(尺寸减半))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 128→64(尺寸减半))self.conv3 = nn.Sequential(nn.Conv2d(64, 128, 5, 1, 2),nn.ReLU(),  # 保持64尺寸)# 关键4:全连接层输入维度正确(128通道 × 64高 × 64宽,基于256输入的两次池化)self.out = nn.Linear(128 * 64 * 64, num_classes)  # 输出维度=类别数def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # 展平特征图(batch, 128*64*64)output = self.out(x)return output# 8. 实例化模型(确保类别数与数据集一致)
model = CNN(num_classes=num_classes).to(device)
print("模型结构:")
print(model, "\n")# 9. 训练/测试函数(不变,修复打印笔误Accuray→Accuracy)
def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 打印损失(每100批次)loss_value = loss.item()if batch_size_num % 100 == 0:print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss = 0correct = 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_pj_loss = test_loss / num_batchestest_acy = correct / size * 100print(f"Avg loss: {test_pj_loss:>7f} \n Accuracy: {test_acy:>5.2f}%")  # 修复笔误Accuray→Accuracy# 10. 训练配置与执行(不变)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 合适的学习率epochs = 50
for j in range(epochs):print(f"===== Epoch {j + 1}/{epochs} =====")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
def predict_single_image(image_path, model_path, transform, device, class_names):"""单张图片分类预测:param image_path: 待预测图片路径(如"C:/test.jpg"):param model_path: 训练好的模型权重路径:param transform: 图片预处理(与valid集一致):param device: 运行设备:param class_names: 类别名称列表(与标签0~9顺序完全对应):return: 预测类别、置信度(百分比)、原始图片(可选显示)"""# 加载模型(与训练时结构一致)model = CNN(num_classes=len(class_names)).to(device)model.load_state_dict(torch.load(model_path, map_location=device))model.eval()  # 切换为推理模式(关闭训练层)# 预处理图片try:# 强制RGB通道,避免图片格式错误image = Image.open(image_path).convert("RGB")image_tensor = transform(image)  # 复用valid集的预处理image_tensor = image_tensor.unsqueeze(0).to(device)  # 增加batch维度(1,3,256,256)except Exception as e:print(f"\n图片处理失败!错误:{e}")return None, None, None# 模型推理(关闭梯度计算)with torch.no_grad():output = model(image_tensor)  # 模型输出:(1,10)pred_prob = torch.softmax(output, dim=1)  # 转为概率(0~1)pred_idx = torch.argmax(pred_prob, dim=1).item()  # 取概率最大的标签索引pred_conf = pred_prob[0][pred_idx].item() * 100  # 置信度(百分比)# 映射到类别名称pred_class = class_names[pred_idx]return pred_class, pred_conf, image# -------------------------- 10. 执行预测(用户需修改2处!) --------------------------
if __name__ == "__main__":# 【用户必须修改1】类别名称列表:与训练时的标签0~9顺序完全一致!# 示例:若标签0=苹果、1=汉堡...,则按此顺序填写(替换为你的实际类别)FOOD_CLASSES = ["八宝粥", "哈密瓜", "圣女果", "巴旦木", "板栗","汉堡", "火龙果", "炸鸡", "瓜子", "生肉", "白萝卜", "胡萝卜", "草莓", "菠萝", "薯条","蛋", "蛋挞", "青菜", "骨肉相连", "鸡翅"]# 【用户必须修改2】待预测图片的路径(替换为你的图片路径,如桌面图片)TEST_IMAGE_PATH = r"E:\pythonProject3\深度学习\卷积神经网络\食物分类\food_dataset\test\八宝粥\img_八宝粥罐_81.jpeg"  # 示例路径pred_class, pred_conf, raw_image = predict_single_image(image_path=TEST_IMAGE_PATH,model_path="food_cnn_model.pth",  # 模型权重路径(与保存路径一致)transform=data_transforms['valid'],  # 与valid集预处理完全一致device=device,class_names=FOOD_CLASSES)# 打印预测结果if pred_class and pred_conf:print(f"图片的类型:八宝粥")print(f"预测结果:{pred_class}")else:print("\n预测失败,请检查图片路径或格式!")

这段代码是基于 PyTorch 实现的 20 类食物图像分类完整流程,涵盖 “数据准备→数据加载→模型构建→模型训练→单图预测” 全链路,且包含多处关键错误修复。

代码解析:

1. 数据准备:train_test函数(生成路径 - 标签映射文件)

功能:遍历训练 / 测试数据集文件夹,生成 “图像路径 + 类别标签” 的 txt 文件(如train1.txt),为后续 Dataset 加载数据提供索引。
核心逻辑

输入root(数据集根目录,如train文件夹)和wj(输出 txt 文件名前缀);

通过os.walk遍历文件夹:

先获取所有子文件夹(即食物类别,如 “八宝粥”“汉堡”),存入wjs

再遍历每个类别下的图像文件,拼接图像完整路径;

wjs.index(类别名)给每个类别分配唯一整数标签(如 “八宝粥” 是wjs[0]则标签为 0);

将 “路径 + 标签” 写入 txt 文件(每行格式:E:\...\img.jpg 0)。

示例输出train1.txt中会包含类似内容,为后续加载数据提供 “图像位置 + 类别” 的映射。

2. 数据预处理:data_transforms(图像增强与标准化)

功能:定义训练集(train)和验证 / 测试集(valid)的图像预处理流程,解决 “图像尺寸不一致”“数据量不足导致过拟合” 问题。
两类预处理的差异(训练集需数据增强,测试集仅基础处理):

类型预处理操作操作目的
train(训练集)1. Resize ([256,256]):统一缩放为 256×256
2. RandomRotation (45):随机旋转 0-45°
3. CenterCrop (256):中心裁剪(旋转后补边再裁回原尺寸)
4. RandomHorizontal/VerticalFlip:随机水平 / 垂直翻转(p=0.5 概率)
5. ColorJitter:调整亮度、对比度等(增加颜色多样性)
6. RandomGrayscale:随机转为灰度图(p=0.1)
7. ToTensor ():转为 PyTorch 张量(HWC→CHW,值归一化到 [0,1])
8. Normalize:用 ImageNet 均值方差标准化([0.485,0.456,0.406],[0.229,0.224,0.225])
1. 统一尺寸:满足模型输入要求
2. 数据增强:增加训练样本多样性,防止模型过拟合
valid(测试集)1. Resize([256,256])
2. ToTensor()
3. Normalize(同训练集)
1. 仅基础处理:不引入随机增强,保证测试结果稳定
2. 标准化需与训练集一致:避免数据分布差异影响推理
3. 数据加载:food_dataset类(自定义 Dataset)

功能:继承 PyTorch 的Dataset类,实现 “按 txt 索引加载图像 + 标签”,并修复 3 个核心错误(原代码潜在问题)。
核心修复与逻辑

初始化方法__init__

关键 1:标签类型转换:将 txt 中读取的标签字符串(如"0")转为整数(int(label_str)),避免后续计算时 “字符串与张量不兼容” 错误;

关键 2:标签范围校验:若标签超出0~num_classes-1(如 20 类时超出 0~19),强制设为 0 并报警,防止模型输出维度不匹配(模型输出是 20 维,标签不能是 20);

存储索引:将图像路径存入self.imgs,整数标签存入self.labels,并打印加载结果(标签范围、图片数量),方便验证数据正确性。

__getitem__方法(按索引取数据):

关键 3:强制 RGB 通道:用Image.open(...).convert('RGB')将图像统一转为 3 通道(避免单通道灰度图、4 通道 RGBA 图与卷积层输入(3 通道)不兼容);

预处理与标签格式:用self.transform处理图像,标签转为torch.long类型(PyTorch 的CrossEntropyLoss强制要求标签为long型);

返回 “处理后图像张量 + 标签张量”,供 DataLoader 调用。

__len__方法:返回数据集总样本数(len(self.imgs)),DataLoader 依赖此确定批次数量。

4. 数据加载器:DataLoader

功能:将food_dataset生成的数据集包装为 “可批量迭代” 的加载器,方便训练时按批次喂给模型。
参数说明

training_data/test_data:自定义的 Dataset 实例;

batch_size=64:每批加载 64 张图像(若 GPU 显存不足,可改为 32 或 16);

shuffle=True/False:训练集shuffle=True(每轮训练前打乱数据,避免模型记忆顺序),测试集shuffle=False(保持数据顺序,方便后续分析错误样本)。

5. 设备配置:自动选择计算设备
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

功能:按优先级自动选择计算设备,适配不同硬件:

优先用cuda(NVIDIA 显卡,训练速度最快);

其次用mps(Apple Silicon 芯片,如 M1/M2,苹果设备的 GPU 加速);

最后用cpu(无 GPU 时的备选,速度慢)。

6. 模型构建:CNN类(卷积神经网络)

功能:定义 20 类分类的 CNN 模型,通过 “卷积提取特征→全连接分类” 完成图像到类别的映射。
网络结构与尺寸变化(输入图像:3×256×256):

网络层操作细节输出特征图尺寸作用
conv11. Conv2d(3→16, 5×5, padding=2)
2. ReLU(激活函数)
3. MaxPool2d (2×2)(步长 2)
16×128×128第一次卷积:提取低级特征(如边缘、纹理),池化缩小尺寸(减少计算量)
conv21. Conv2d(16→32, 5×5, padding=2)
2. ReLU
3. Conv2d(32→64, 5×5, padding=2)
4. ReLU
5. MaxPool2d(2×2)
64×64×64第二次卷积:提取中级特征(如局部形状),再次池化缩小尺寸
conv31. Conv2d(64→128, 5×5, padding=2)
2. ReLU
128×64×64第三次卷积:提取高级特征(如食物的整体轮廓),不池化(保留尺寸)
out(全连接层)Linear(128×64×64 → 20)20(向量)将 128×64×64 的特征图展平为 1 维向量(1286464=524,288),输出 20 个类别的 “得分”

关键设计

卷积层用padding=2:保证卷积后尺寸不变(公式:输出尺寸=输入尺寸 - 核尺寸 + 2*padding + 1,5×5 核时256-5+4+1=256);

池化层用MaxPool2d(2):尺寸减半,通道数不变,既减少计算量,又增强特征鲁棒性;

num_classes=20:与数据集类别数一致,确保模型输出维度与标签匹配。

7. 模型训练与测试:train/test函数

功能:实现模型的迭代训练(更新参数)和测试(评估性能),核心是 “损失计算→反向传播→参数优化”。

(1)训练函数 train(dataloader, model, loss_fn, optimizer)

流程

model.train():将模型设为训练模式(启用 Dropout、BatchNorm 等训练专属层,若后续添加);

遍历dataloader的每一批数据(X:图像批次,y:标签批次):数据移到device(GPU/CPU);pred = model(X):模型输出预测得分(64×20,每行为 1 张图的 20 类得分);

loss = loss_fn(pred, y):计算损失(用CrossEntropyLoss,自动将得分与标签对比,适合多分类);

反向传播:optimizer.zero_grad()(清空上一轮梯度)→ loss.backward()(计算梯度)→ optimizer.step()(用梯度更新参数);

每 100 批次打印损失:监控训练进度,若损失不下降则需调整学习率或 batch_size。

(2)测试函数 test(dataloader, model, loss_fn)

流程

  1. model.eval():将模型设为评估模式(关闭 Dropout、固定 BatchNorm 参数,保证推理稳定);
  2. with torch.no_grad():关闭梯度计算(测试时无需更新参数,减少内存占用);
  3. 遍历测试集批次:计算test_loss(累加每批损失,最后求平均);计算correct(预测类别 = 真实类别(pred.argmax(1) == y)的样本数,累加后求准确率);
  4. 打印 “平均损失” 和 “准确率”:评估模型泛化能力(准确率越高,模型分类效果越好);
  5. 修复笔误:将原代码的Accuray改为Accuracy,避免打印错误。
8. 训练执行:初始化参数 + 迭代训练

核心代码

loss_fn = nn.CrossEntropyLoss()  # 多分类专用损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam优化器(学习率0.001,常用且稳定)
epochs = 50  # 训练轮次(50轮足够初步收敛,可根据准确率调整)for j in range(epochs):print(f"===== Epoch {j + 1}/{epochs} =====")train(train_dataloader, model, loss_fn, optimizer)  # 每轮先训练test(test_dataloader, model, loss_fn)  # 再测试评估

逻辑:每轮训练后立即测试,观察模型在 “未见过的测试集” 上的性能,若测试准确率不再提升,可提前停止训练(防止过拟合)。

9. 单图预测:predict_single_image函数(模型推理)

功能:加载训练好的模型权重,对单张新图片进行分类,返回 “预测类别 + 置信度 + 原始图片”。
流程

  1. 加载模型

    • 实例化CNN(类别数 =len(class_names),即 20),移到device
    • model.load_state_dict(torch.load(model_path, map_location=device)):加载训练好的权重文件(如food_cnn_model.pth,需与训练时保存路径一致);
    • model.eval():设为评估模式。
  2. 预处理图片

    • Image.open(image_path).convert("RGB")读取并强制 RGB 通道;
    • data_transforms['valid']预处理(与测试集一致,保证数据分布相同);
    • image_tensor.unsqueeze(0):增加 batch 维度(模型输入需是 “批次 × 通道 × 高 × 宽”,单图时为 1×3×256×256)。
  3. 模型推理

    • with torch.no_grad():关闭梯度计算;
    • output = model(image_tensor):输出预测得分;
    • pred_prob = torch.softmax(output, dim=1):将得分转为概率(0~1,所有类别概率和为 1);
    • pred_idx = torch.argmax(pred_prob, dim=1).item():取概率最大的类别索引;
    • pred_conf = pred_prob[0][pred_idx].item() * 100:计算置信度(百分比,如 95.2%)。
  4. 返回结果:映射索引到类别名称(如pred_idx=0→“八宝粥”),返回预测类别、置信度和原始图片。

10. 主函数:执行单图预测(用户需修改的关键部分)

核心代码

if __name__ == "__main__":# 【必须修改1】类别名称列表:与训练时标签0~19顺序完全一致(否则预测类别错乱)FOOD_CLASSES = ["八宝粥", "哈密瓜", "圣女果", "巴旦木", "板栗","汉堡", "火龙果", "炸鸡", "瓜子", "生肉", "白萝卜", "胡萝卜", "草莓", "菠萝", "薯条","蛋", "蛋挞", "青菜", "骨肉相连", "鸡翅"]# 【必须修改2】待预测图片路径(需是真实存在的图片路径,格式支持jpg/png等)TEST_IMAGE_PATH = r"E:\...\img_八宝粥罐_81.jpeg"# 调用预测函数pred_class, pred_conf, raw_image = predict_single_image(...)# 打印结果if pred_class and pred_conf:print(f"图片的类型:八宝粥")  # 此处为硬编码,建议改为动态打印(如f"原始标注类型:...",若有标注)print(f"预测结果:{pred_class}")else:print("预测失败,检查路径/格式!")

注意事项

  • 硬编码问题:print(f"图片的类型:八宝粥")是固定值,若测试其他图片会不准确,建议改为 “若有原始标注则读取标注,无则不打印”;
  • 模型权重路径:model_path="food_cnn_model.pth"需确保该文件存在(训练时需添加torch.save(model.state_dict(), "food_cnn_model.pth")保存权重,原代码未写,需补充)。

三、核心修复点总结(避免原代码错误)

  1. 标签处理:字符串转整数 + 范围校验,防止类型错误和维度不匹配;
  2. 图像通道:强制 RGB,避免卷积层输入通道不兼容;
  3. 模型维度:全连接层输入128×64×64与特征图尺寸匹配,输出维度 = 类别数;
  4. 损失函数标签类型:标签转为torch.long,符合CrossEntropyLoss要求;
  5. 笔误修复:AccurayAccuracy,避免打印错误。

四、使用建议

  1. 补充模型保存代码:在训练结束后添加torch.save(model.state_dict(), "food_cnn_model.pth"),否则预测时无权重文件;
  2. 调整超参数:若训练过拟合(训练准确率高、测试准确率低),可减少epochs、降低batch_size或增加数据增强;若欠拟合(准确率低),可加深网络或调大学习率;
  3. 验证类别顺序:FOOD_CLASSES必须与train_test函数分配的类别索引一致(可通过train1.txt的标签对应验证);
  4. 处理预测硬编码:将print(f"图片的类型:八宝粥")改为动态逻辑,如读取图片所在文件夹名称作为原始类别(若测试集按类别分文件夹)。
http://www.xdnf.cn/news/19637.html

相关文章:

  • 码住!辉芒微MCU型号规则详细解析
  • Kafka 架构详解
  • 动子注册操作【2025.9.2学习记录】
  • MVP架构深层剖析-从六大设计原则的实现角度到用依赖注入深度解耦
  • Elasticsearch 核心知识与常见问题解析
  • MCU上跑AI—实时目标检测算法探索
  • 【 HarmonyOS 6 】HarmonyOS智能体开发实战:Function组件和智能体创建
  • 空间不足将docker挂载到其他位置
  • 03_网关ip和端口映射(路由器转发)操作和原理
  • 梯度消失问题:深度学习中的「记忆衰退」困境与解决方案
  • React 学习笔记4 Diffing/脚手架
  • 2025了,你知道electron-vite吗?
  • 网络原理——HTTP/HTTPS
  • ImageMagick命令行图片工具:批量实现格式转换与压缩,支持水印添加及GIF动态图合成
  • 2条命令,5秒安装,1秒启动!Vite项目保姆级上手指南
  • 鸿蒙NEXT界面交互全解析:弹出框、菜单、气泡提示与模态页面的实战指南
  • 开源的聚合支付系统源码/易支付系统 /三方支付系统
  • Erlang 利用 recon 排查热点进程
  • 人工智能之数学基础:分布函数对随机变量的概率分布情况进行刻画
  • 微信小程序 navigateTo 栈超过多层后会失效
  • 在 Delphi 5 中获取 Word 文档页数的方法
  • 小程序蓝牙低功耗(BLE)外围设备开发指南
  • 365 天技术创作手记:从一行代码到四万同行者的相遇
  • C++多线程编程:std::thread, std::async, std::future
  • Jenkins Pipeline 语法
  • 第 12 篇:网格边界安全 - Egress Gateway 与最佳实践
  • python中的zip() 函数介绍及使用说明
  • 基于Spark的新冠肺炎疫情实时监控系统_django+spider
  • HTML第三课:特殊元素
  • 跨境电商账号风控核心:IP纯净度与浏览器指纹的防护策略