当前位置: 首页 > news >正文

从 0 到 1 实现 PyTorch 食物图像分类:核心知识点与完整实

食物图像分类是计算机视觉的经典任务之一,其核心是让机器 “看懂” 图像中的食物类别。随着深度学习的发展,卷积神经网络(CNN)凭借强大的特征提取能力,成为图像分类的主流方案。本文将基于 PyTorch 框架,从代码实战出发,拆解食物图像分类项目中的核心知识点,包括环境搭建、数据预处理、数据集构建、CNN 模型设计、模型训练与测试、单图预测等,带大家从零搭建一个能识别 20 类食物的分类系统。

# 导入必要的库
import torch  # PyTorch核心库,用于构建和训练神经网络
from torch import nn  # 神经网络模块,包含各种层和损失函数
from torch.utils.data import Dataset, DataLoader  # 数据集和数据加载器,用于数据处理
import numpy as np  # 数值计算库,可用于数据预处理等
from PIL import Image  # 图像处理库,用于读取和处理图像
from torchvision import transforms  # 图像转换工具,用于数据增强和预处理
import os  # 操作系统接口,用于文件路径处理等# 定义数据转换策略:训练集使用数据增强,验证集/测试集保持一致的基础转换
data_transforms = {'train':  # 训练集转换(包含数据增强,增加样本多样性)transforms.Compose([transforms.Resize([300, 300]),  # 先将图像调整为300x300transforms.RandomRotation(45),  # 随机旋转(-45~45度),增强旋转不变性transforms.CenterCrop(256),  # 中心裁剪到256x256,去除旋转后的黑边transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转transforms.RandomVerticalFlip(p=0.5),  # 50%概率垂直翻转transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),  # 随机调整亮度、对比度、饱和度和色调transforms.RandomGrayscale(p=0.1),  # 10%概率转为灰度图transforms.ToTensor(),  # 转为Tensor格式([C, H, W]),并将像素值归一化到[0,1]transforms.Normalize(  # 使用ImageNet的均值和标准差进行标准化[0.485, 0.456, 0.406],  # 均值(RGB三个通道)[0.229, 0.224, 0.225]  # 标准差(RGB三个通道))]),'valid':  # 验证集/测试集转换(无增强,保持数据一致性)transforms.Compose([transforms.Resize([256, 256]),  # 调整为256x256,与训练集裁剪后尺寸一致transforms.ToTensor(),  # 转为Tensor]),
}# -------------------------- 2. 自定义数据集类 --------------------------
class food_dataset(Dataset):"""自定义食物图像数据集类,继承自PyTorch的Dataset用于加载图像路径和对应标签,并进行预处理"""def __init__(self, file_path, transform=None):"""初始化数据集:param file_path: 存储图像路径和标签的文本文件路径:param transform: 图像转换函数(预处理/数据增强)"""self.file_path = file_path  # 文本文件路径self.transform = transform  # 转换函数self.imgs = []  # 存储所有图像路径self.labels = []  # 存储对应标签# 读取文件列表(每行格式:图片路径 数字标签)with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip()  # 去除首尾空格和换行符if not line:  # 跳过空行continue# 按空格分割路径和标签(假设格式严格,无多余空格)img_path, label = line.split(' ')self.imgs.append(img_path)self.labels.append(label)def __len__(self):"""返回数据集样本数量"""return len(self.imgs)def __getitem__(self, index):"""根据索引获取单个样本(图像和标签):param index: 样本索引:return: 处理后的图像张量和标签张量"""# 读取图片并强制转为RGB(避免灰度图导致的通道数不匹配问题)try:image = Image.open(self.imgs[index]).convert('RGB')  # 确保3通道输入except Exception as e:# 捕获读取错误,便于调试raise ValueError(f"读取图片 {self.imgs[index]} 失败:{e}")# 应用转换(预处理/数据增强)if self.transform:image = self.transform(image)# 处理标签:转为整数类型的张量(PyTorch分类任务要求标签为long类型)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label# 加载数据集
# 注意:需确保train.txt和test.txt文件存在,每行格式为「图片路径 数字标签」
try:# 加载训练集(使用训练集转换)training_data = food_dataset(file_path='./train.txt', transform=data_transforms['train'])# 加载测试集(使用验证集转换)test_data = food_dataset(file_path='./test.txt', transform=data_transforms['valid'])
except FileNotFoundError:# 捕获文件不存在错误,提示用户raise FileNotFoundError("请确保 train.txt 和 test.txt 文件在当前目录下")# 创建数据加载器(批量加载数据,支持打乱和多进程)
train_dataloader = DataLoader(training_data,batch_size=8,  # 批大小:每次加载8张图片shuffle=True  # 训练时打乱数据顺序,增强训练效果
)
test_dataloader = DataLoader(test_data,batch_size=8,  # 测试时也用相同批大小shuffle=True  # 测试时打乱不影响结果,主要便于观察不同样本
)# 设备配置:优先使用GPU(cuda),其次是Apple M系列芯片(mps),最后是CPU
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')  # 打印使用的设备# 定义CNN模型(卷积神经网络)
class CNN(nn.Module):"""自定义卷积神经网络模型,用于食物图像分类包含4个卷积块和1个全连接输出层"""def __init__(self):super().__init__()  # 调用父类nn.Module的初始化方法# 第一个卷积块:1次卷积 + ReLU激活 + 最大池化self.conv1 = nn.Sequential(# 卷积层:输入3通道(RGB),输出16通道,卷积核5x5,步长1,填充2(保持尺寸)nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),  # 激活函数,引入非线性nn.MaxPool2d(kernel_size=2),  # 最大池化:尺寸减半(256→128))# 第二个卷积块:2次卷积 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),  # 输入16通道,输出32通道nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),  # 输入32通道,输出32通道nn.ReLU(),nn.MaxPool2d(2),  # 尺寸减半(128→64))# 第三个卷积块:2次卷积 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 输入32通道,输出64通道nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2),  # 输入64通道,输出128通道nn.ReLU(),nn.MaxPool2d(2),  # 尺寸减半(64→32))# 第四个卷积块:1次卷积 + ReLU(无池化,保持尺寸)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2),  # 输入128通道,输出128通道nn.ReLU(),  # 输出尺寸:32×32,通道数128)# 全连接输出层:将特征映射到20个类别(食物种类)# 输入尺寸计算:128通道 × 32高 × 32宽(经多次池化后的特征图尺寸)self.out = nn.Linear(128 * 32 * 32, 20)  # 20类:需与标签数量一致def forward(self, x):"""前向传播:定义数据在网络中的流动路径:param x: 输入张量,形状为[batch_size, 3, 256, 256]:return: 输出张量,形状为[batch_size, 20](各类别的预测分数)"""x = self.conv1(x)  # 经第一个卷积块处理x = self.conv2(x)  # 经第二个卷积块处理x = self.conv3(x)  # 经第三个卷积块处理x = self.conv4(x)  # 经第四个卷积块处理x = x.view(x.size(0), -1)  # 展平特征图:[batch_size, 128*32*32]output = self.out(x)  # 经全连接层输出预测结果return output# -------------------------- 训练与测试函数 --------------------------
def train(dataloader, model, loss_fn, optimizer):"""训练模型的函数:param dataloader: 训练数据集加载器:param model: 待训练的模型:param loss_fn: 损失函数(用于计算预测误差):param optimizer: 优化器(用于更新模型参数)"""model.train()  # 开启训练模式(启用Dropout、BatchNorm等训练特定行为)batch_size_num = 1  # 记录当前批次编号for X, y in dataloader:# 将数据移动到指定设备(GPU/CPU)X, y = X.to(device), y.to(device)# 前向传播:计算模型预测结果pred = model(X)# 计算损失(预测值与真实标签的差距)loss = loss_fn(pred, y)# 反向传播与参数更新optimizer.zero_grad()  # 清空上一轮的梯度(避免梯度累积)loss.backward()  # 反向传播计算梯度optimizer.step()  # 根据梯度更新模型参数# 打印损失(每2个batch打印一次,便于监控训练过程)loss_val = loss.item()  # 获取损失的标量值if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f}  [batch: {batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):model.eval()  # 开启评估模式(关闭Dropout、固定BatchNorm参数等)size = len(dataloader.dataset)  # 测试集总样本数num_batches = len(dataloader)  # 测试集批次数test_loss, correct = 0, 0  # 总损失和正确预测数# 关闭梯度计算(测试时不需要更新参数,节省计算资源)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)  # 数据移至设备pred = model(X)  # 预测test_loss += loss_fn(pred, y).item()  # 累加损失# 统计正确预测数:取预测概率最大的类别与真实标签比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batches  # 平均损失correct /= size  # 准确率print(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")# -------------------------- 单张图片预测函数 --------------------------
def predict_single_image(image_path, model, transform, device, label_map):"""对单张图片进行预测:param image_path: 图片路径:param model: 训练好的模型:param transform: 图像预处理函数(与测试集一致):param device: 计算设备:param label_map: 标签映射字典(数字标签→食物名称):return: 预测的食物名称"""# 读取并预处理图片(与测试集预处理一致)image = Image.open(image_path).convert('RGB')  # 确保3通道image = transform(image)  # 应用预处理(Resize和ToTensor)# 增加batch维度(模型要求输入格式为[batch, C, H, W],这里batch=1)image = image.unsqueeze(0).to(device)# 模型预测model.eval()  # 开启评估模式with torch.no_grad():  # 关闭梯度计算pred_logits = model(image)  # 得到预测分数(logits)# 取概率最大的类别标签(argmax(1)按行取最大值索引)pred_label = pred_logits.argmax(1).item()# 映射为食物名称if pred_label not in label_map:raise KeyError(f"预测标签 {pred_label} 不在标签映射字典中")return label_map[pred_label]# -------------------------- 主程序 --------------------------
if __name__ == "__main__":# 初始化模型、损失函数、优化器model = CNN().to(device)  # 创建模型并移至设备loss_fn = nn.CrossEntropyLoss()  # 多分类问题常用交叉熵损失# Adam优化器:自适应学习率,训练效果较好,学习率0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型(100轮)epochs = 100for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n----------------------------")train(train_dataloader, model, loss_fn, optimizer)print("Training Done!")# 测试模型在测试集上的性能test(test_dataloader, model, loss_fn)# 定义标签映射字典:数字标签→食物名称# 需与数据集的标签完全对应(顺序和数量一致)label_to_food = {0: "八宝粥", 1: "巴旦木", 2: "白萝卜", 3: "板栗", 4: "菠萝",5: "草莓", 6: "蛋", 7: "蛋挞", 8: "骨肉相连", 9: "瓜子",10: "哈密瓜", 11: "汉堡", 12: "胡萝卜", 13: "火龙果", 14: "鸡翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯条", 19: "炸鸡"}# 输入图片路径并预测image_path = input("请输入图片路径:")  # 用户输入待预测图片路径true_food = input("请输入该图片的真实食物名称:")  # 用户输入真实标签(用于对比)# 执行预测并输出结果predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food)# 输出对比结果print("\n" + "-" * 50)print(f"预测结果:{predicted_food}")print(f"真实结果:{true_food}")print(f"判断:{'预测正确' if predicted_food == true_food else '预测错误'}")print("-" * 50)

二、数据预处理:让数据 “适配” 模型

在深度学习中,数据预处理的质量直接影响模型性能。原始图像可能存在尺寸不一、像素值范围差异大、样本数量不足等问题,需通过预处理将其转化为模型可接受的格式,并通过数据增强提升模型泛化能力。

本项目的预处理逻辑集中在data_transforms字典中,分 “训练集” 和 “验证集 / 测试集” 两种策略,我们逐一拆解其设计思路。

2.1 为什么要区分训练集与验证集预处理?

  • 训练集:需要通过 “数据增强” 增加样本多样性,避免模型过拟合(即模型只记住训练样本,对新样本识别能力差)。
  • 验证集 / 测试集:需保持数据的 “真实性”,仅进行基础预处理(如 Resize、ToTensor),确保评估结果能反映模型的实际泛化能力。

2.2 训练集数据增强:每一步的作用与原理

训练集的预处理链为:Resize → RandomRotation → CenterCrop → RandomHorizontalFlip → RandomVerticalFlip → ColorJitter → RandomGrayscale → ToTensor → Normalize,我们逐个解析:

(1)Resize ([300, 300]):统一初始尺寸

将所有图像调整为 300×300 像素。为什么不直接调整为最终的 256×256?因为后续会进行旋转和裁剪,预留一定尺寸可避免旋转后出现黑边。

(2)RandomRotation (45):随机旋转

随机将图像旋转 - 45°~45°。食物在拍摄时可能有不同角度(如躺着的汉堡、竖放的胡萝卜),旋转增强能让模型对角度不敏感,提升鲁棒性。

(3)CenterCrop (256):中心裁剪

将旋转后的图像从中心裁剪为 256×256。旋转会导致图像边缘出现黑边,裁剪可去除黑边,同时将图像尺寸统一为模型输入尺寸(256×256)。

(4)RandomHorizontalFlip (p=0.5) & RandomVerticalFlip (p=0.5):随机翻转
  • 水平翻转(50% 概率):模拟 “左右镜像” 的食物(如翻转后的草莓外观不变)。
  • 垂直翻转(50% 概率):模拟 “上下颠倒” 的场景(如掉落的薯条)。
    翻转操作不改变食物的核心特征,但能增加样本多样性,且计算成本低。
(5)ColorJitter (0.1, 0.1, 0.1, 0.1):随机颜色抖动

调整图像的亮度、对比度、饱和度、色调,各参数的取值范围为 0~1(0 表示不调整,1 表示最大调整幅度)。
食物图像的颜色易受光照影响(如白天和夜晚拍摄的青菜颜色不同),颜色抖动能让模型对光照变化不敏感。

(6)RandomGrayscale (p=0.1):随机灰度化

10% 概率将彩色图像转为灰度图。虽然食物的颜色是重要特征,但灰度化能迫使模型关注食物的形状、纹理等更本质的特征,避免过度依赖颜色信息(如红色的草莓和红色的圣女果,需通过形状区分)。

(7)ToTensor ():转为 Tensor 格式

将 PIL 图像(H×W×C,像素值 0~255)转为 PyTorch Tensor(C×H×W,像素值归一化到 0~1)。

  • 维度转换:模型要求输入为 “通道优先”(C×H×W),而 PIL 图像是 “高度优先”(H×W×C),需通过 ToTensor 调整。
  • 归一化:将像素值从 0~255 缩放到 0~1,避免大数值导致模型梯度爆炸。
(8)Normalize ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):标准化

使用 ImageNet 数据集的均值和标准差对 Tensor 进行标准化,公式为:
标准化后像素值 = (原始像素值 - 均值) / 标准差
为什么用 ImageNet 的参数?因为本项目后续可扩展为迁移学习(使用预训练模型),而预训练模型是在 ImageNet 上训练的,使用相同的标准化参数能让模型更快收敛。

transforms.Compose([transforms.Resize([300, 300]),  # 先将图像调整为300x300transforms.RandomRotation(45),  # 随机旋转(-45~45度),增强旋转不变性transforms.CenterCrop(256),  # 中心裁剪到256x256,去除旋转后的黑边transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转transforms.RandomVerticalFlip(p=0.5),  # 50%概率垂直翻转transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),  # 随机调整亮度、对比度、饱和度和色调transforms.RandomGrayscale(p=0.1),  # 10%概率转为灰度图transforms.ToTensor(),  # 转为Tensor格式([C, H, W]),并将像素值归一化到[0,1]transforms.Normalize(  # 使用ImageNet的均值和标准差进行标准化[0.485, 0.456, 0.406],  # 均值(RGB三个通道)[0.229, 0.224, 0.225]  # 标准差(RGB三个通道))

2.3 验证集预处理

验证集的预处理链为:Resize([256, 256]) → ToTensor(),仅保留基础操作:

  • Resize ([256, 256]):直接将图像调整为 256×256,无需旋转(避免引入非真实样本)。
  • ToTensor ():与训练集一致,确保数据格式统一。
    (注:代码中验证集未做 Normalize,实际项目中建议与训练集保持一致,此处可根据需求调整)

三、自定义 Dataset:PyTorch 数据加载的核心

PyTorch 通过DatasetDataLoader实现数据加载,其中Dataset负责 “定义数据来源和格式”,DataLoader负责 “批量加载和并行处理”。本项目自定义了food_dataset类,用于加载食物图像和对应标签,我们详细解析其实现逻辑。

3.1 Dataset 的核心作用

Dataset是一个抽象类,要求子类必须实现三个方法:

  1. __init__:初始化数据集(读取文件列表、加载预处理函数)。
  2. __len__:返回数据集的总样本数。
  3. __getitem__:根据索引返回单个样本(图像 + 标签)。
    这三个方法确保了 PyTorch 能高效地迭代访问数据。

3.2 food_dataset 类逐方法解析

(1)init:初始化数据列表
def __init__(self, file_path, transform=None):self.file_path = file_path  # 存储图像路径和标签的txt文件路径self.transform = transform  # 预处理函数self.imgs = []  # 存储所有图像路径self.labels = []  # 存储对应标签# 读取txt文件,解析图像路径和标签with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip()  # 去除首尾空格和换行符if not line:  # 跳过空行(避免解析错误)continueimg_path, label = line.split(' ')  # 按空格分割路径和标签self.imgs.append(img_path)self.labels.append(label)

  • txt 文件格式要求:每行需包含 “图像路径” 和 “数字标签”,用空格分隔
    其中 “0” 对应 “八宝粥”,“1” 对应 “巴旦木”,需与后续label_to_food字典一致。
(2)len:返回样本总数
def __len__(self):return len(self.imgs)

简单直接,返回self.imgs的长度),DataLoader会通过该方法确定迭代次数。

(3)getitem:返回单个样本
def __getitem__(self, index):# 读取图像并强制转为RGB(避免灰度图通道数问题)try:image = Image.open(self.imgs[index]).convert('RGB')except Exception as e:raise ValueError(f"读取图片 {self.imgs[index]} 失败:{e}")# 应用预处理if self.transform:image = self.transform(image)# 处理标签:转为int64类型Tensor(PyTorch分类任务要求)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label

这是Dataset的核心方法,需重点关注三个细节:

  1. 强制 RGB 格式convert('RGB')确保所有图像都是 3 通道(避免部分灰度图是 1 通道,导致模型输入维度不匹配)。
  2. 异常处理try-except捕获图像读取错误(如路径错误、图像损坏),并明确提示错误位置,便于调试。
  3. 标签类型:将标签转为torch.int64(即 LongTensor),因为 PyTorch 的CrossEntropyLoss要求标签为 Long 类型。

3.3 如何准备自己的数据集?

  1. 收集图像:每个食物类别收集至少 100 张图像(样本越多,模型性能越好),建议按类别分文件夹存储
  2. 生成 txt 文件:编写脚本遍历图像文件夹,生成train.txttest.txt

    import os# 数据集根目录
    train_root = "./dataset/train"
    test_root = "./dataset/test"
    # 标签映射(与后续一致)
    label_to_food = {0: "八宝粥", 1: "巴旦木", ..., 19: "炸鸡"}
    # 反向映射:食物名称→数字标签
    food_to_label = {v: k for k, v in label_to_food.items()}# 生成train.txt
    with open("train.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(train_root):food_dir = os.path.join(train_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")# 生成test.txt(逻辑同上)
    with open("test.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(test_root):food_dir = os.path.join(test_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")
    
  3. 检查路径:确保 txt 文件中的图像路径与实际文件路径一致

四、DataLoader:批量加载与并行处理

Dataset定义了数据的 “来源”,而DataLoader则负责将数据 “批量加载” 到模型中,并支持并行处理,提升数据加载速度。

4.1 DataLoader 的核心参数解析

本项目的DataLoader初始化代码如下:

train_dataloader = DataLoader(training_data, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=8, shuffle=True)

核心参数含义:

  • batch_size:每次加载的样本数量(批大小),需根据 GPU 显存调整
  • shuffle:是否打乱数据顺序。训练集设为True,验证集可设为False,本项目测试集设为True是为了观察不同样本的预测效果。

4.2 DataLoader 与 Dataset 的协作流程

DataLoader的工作流程可概括为:

  1. 调用Dataset.__len__()获取总样本数,计算总批次数(总样本数 //batch_size)。
  2. shuffle=True,则在每个 epoch(训练轮次)开始前打乱样本索引。
  3. 对每个批次,根据索引调用Dataset.__getitem__()获取单个样本,组装成一个批次的 Tensor(形状为 [batch_size, C, H, W])。
  4. 将批次数据移动到指定设备(CUDA/MPS/CPU),供模型训练或测试。

4.3 数据加载到设备的逻辑

X, y = X.to(device), y.to(device)

其中device是通过以下代码确定的:

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

  • 为什么要移动设备? 模型和数据必须在同一设备上才能进行计算(如模型在 CUDA 上,数据也需在 CUDA 上),否则会报错。
  • 设备优先级:优先使用 CUDA(NVIDIA GPU),其次是 MPS(Apple M 系列),最后是 CPU

五、CNN 模型构建:从卷积到全连接的特征提取

卷积神经网络(CNN)是图像分类的核心,其通过 “卷积层提取局部特征→池化层降维→全连接层分类” 的流程,实现对图像的识别。本项目的 CNN 模型包含 4 个卷积块和 1 个全连接层,我们逐一解析其设计思路和尺寸计算。

5.1 CNN 的核心组件与作用

在解析代码前,先回顾 CNN 的三个核心组件:

  1. 卷积层(Conv2d):通过卷积核滑动提取图像的局部特征(如边缘、纹理、形状),输出 “特征图”(Feature Map)。
  2. 激活函数(ReLU):引入非线性,让模型能拟合复杂的特征关系(避免线性模型的表达能力不足)。
  3. 池化层(MaxPool2d):对特征图进行下采样,降低维度和计算量,同时增强模型对特征位置的鲁棒性。

5.2 模型代码逐块解析

模型定义代码如下,我们按 “卷积块 1→卷积块 2→卷积块 3→卷积块 4→全连接层” 的顺序解析:

class CNN(nn.Module):def __init__(self):super().__init__()# 卷积块1:1次卷积 + ReLU + 最大池化self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 池化后尺寸:256→128)# 卷积块2:2次卷积 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 尺寸:128→64)# 卷积块3:2次卷积 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 尺寸:64→32)# 卷积块4:1次卷积 + ReLU(无池化)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2),nn.ReLU(),  # 输出尺寸:32×32,通道数128)# 全连接层:映射到20类self.out = nn.Linear(128 * 32 * 32, 20)
池化层参数计算:尺寸减半

MaxPool2d(kernel_size=2)表示池化核大小为 2×2,步长默认等于核大小(即 2),因此输出尺寸为输入尺寸的 1/2:

  • conv1池化前尺寸:256×256 → 池化后:128×128
  • conv2池化前尺寸:128×128 → 池化后:64×64
  • conv3池化前尺寸:64×64 → 池化后:32×32
各卷积块的输出特征
  • conv1:输出特征图形状为 [batch_size, 16, 128, 128],提取的是图像的低级特征(如边缘、颜色块)。
  • conv2:输出形状为 [batch_size, 32, 64, 64],通过 2 次卷积提取更复杂的特征(如食物的局部轮廓)。
  • conv3:输出形状为 [batch_size, 128, 32, 32],通道数增加到 128,特征更抽象(如食物的结构特征)。
  • conv4:输出形状为 [batch_size, 128, 32, 32],无池化,进一步细化特征(避免池化导致的特征损失)。
(4)全连接层:从特征到分类

全连接层self.out的输入维度是128×32×32,这是由conv4的输出特征图形状决定的:

  • conv4输出:[batch_size, 128, 32, 32] → 展平后为 [batch_size, 128×32×32](展平操作在forward中通过x.view(x.size(0), -1)实现)。
  • 输出维度:20,对应 20 种食物类别,每个维度输出该类别的 “预测分数”(后续通过argmax取分数最高的类别作为预测结果)。

5.3 forward 方法:定义数据流动路径

def forward(self, x):x = self.conv1(x)  # 经卷积块1处理x = self.conv2(x)  # 经卷积块2处理x = self.conv3(x)  # 经卷积块3处理x = self.conv4(x)  # 经卷积块4处理x = x.view(x.size(0), -1)  # 展平:[batch_size, 128*32*32]output = self.out(x)  # 全连接层输出return output
  • 展平操作x.view(x.size(0), -1)将 4 维特征图(batch, C, H, W)转为 2 维张量(batch, C×H×W),因为全连接层仅接受 2 维输入。
  • 数据流动:输入图像([batch, 3, 256, 256])→ 卷积块 1→2→3→4 → 展平 → 全连接层 → 输出([batch, 20])。

5.4 模型初始化与设备移动

模型初始化代码如下:

model = CNN().to(device)
  • CNN()创建模型实例,to(device)将模型参数移动到指定设备(CUDA/MPS/CPU),确保模型和数据在同一设备上计算。

六、模型训练与测试:从损失下降到性能评估

模型构建完成后,需通过训练让模型 “学习” 食物特征,再通过测试评估模型的泛化能力。本项目定义了traintest两个函数,分别实现训练和测试逻辑。

6.1 训练函数:让模型 “学习”

训练函数的核心是 “前向传播计算损失→反向传播更新参数”,代码如下:

def train(dataloader, model, loss_fn, optimizer):model.train()  # 开启训练模式(启用Dropout、BatchNorm训练行为)batch_size_num = 1  # 批次编号,用于打印损失for X, y in dataloader:# 数据移动到设备X, y = X.to(device), y.to(device)# 1. 前向传播:计算预测结果pred = model(X)# 2. 计算损失:预测值与真实标签的差距loss = loss_fn(pred, y)# 3. 反向传播与参数更新optimizer.zero_grad()  # 清空上一轮梯度(避免累积)loss.backward()        # 反向传播计算梯度optimizer.step()       # 根据梯度更新模型参数# 打印损失(每2个批次打印一次)loss_val = loss.item()if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f}  [batch: {batch_size_num}]")batch_size_num += 1
(1)关键步骤解析
  1. model.train():开启训练模式,对含有 Dropout、BatchNorm 的模型至关重要:

    • Dropout:训练时随机 “关闭” 部分神经元,防止过拟合;测试时不关闭。
    • BatchNorm:训练时使用批次的均值和方差归一化;测试时使用训练阶段累积的均值和方差。
  2. 前向传播(Forward Pass)

    • pred = model(X):将批次数据输入模型,得到预测结果([batch, 20])。
    • loss = loss_fn(pred, y):计算损失,本项目使用CrossEntropyLoss(多分类任务的常用损失函数)。
  3. 反向传播(Backward Pass)与参数更新

    • optimizer.zero_grad():清空梯度。若不清空,梯度会累积到上一轮,导致参数更新错误。
    • loss.backward():根据损失计算各参数的梯度(
    • optimizer.step():根据梯度更新模型参数
(2)损失函数与优化器选择

本项目使用的损失函数和优化器如下:

loss_fn = nn.CrossEntropyLoss()  # 多分类交叉熵损失
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam优化器
  • CrossEntropyLoss:适用于多分类任务
  • Adam 优化器:自适应学习率优化器,收敛速度快,对学习率不敏感,是深度学习中最常用的优化器之一。lr=0.001是常用的初始学习率,可根据训练情况调整

6.2 测试函数:评估模型泛化能力

测试函数的核心是 “计算模型在测试集上的准确率和平均损失”,代码如下:

def test(dataloader, model, loss_fn):model.eval()  # 开启评估模式(关闭Dropout、固定BatchNorm)size = len(dataloader.dataset)  # 测试集总样本数num_batches = len(dataloader)    # 测试集批次数test_loss, correct = 0, 0        # 总损失和正确预测数# 关闭梯度计算(节省资源,避免参数更新)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累加损失test_loss += loss_fn(pred, y).item()# 累加正确预测数correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batchescorrect /= sizeprint(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")
(1)关键步骤解析
  1. model.eval():开启评估模式,与model.train()对应,确保模型在测试时的行为与训练时一致(如关闭 Dropout)。
  2. torch.no_grad():上下文管理器,关闭梯度计算。测试时无需更新参数,关闭梯度可大幅减少内存占用和计算时间。
  3. 准确率计算
    • pred.argmax(1):对每个样本,取预测分数最高的类别(维度 1 是类别维度,[batch, 20]→[batch, 1])。
    • (pred.argmax(1) == y):比较预测类别与真实标签,得到布尔张量(True = 正确,False = 错误)。
    • type(torch.float).sum().item():将布尔张量转为 float(True=1,False=0),求和得到正确预测数,再转为 Python 标量。
  4. 结果解读
    • Accuracy:准确率(正确预测数 / 总样本数),反映模型的整体识别能力,越高越好。
    • Avg Loss:平均损失,反映模型预测值与真实标签的平均差距,越低越好。

6.3 训练流程与轮次设置

训练流程代码如下:

# 训练轮次(epochs)
epochs = 100
for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n")train(train_dataloader, model, loss_fn, optimizer)
print("Training Done!")# 测试模型
test(test_dataloader, model, loss_fn)
  • epochs(训练轮次):表示模型将遍历整个训练集的次数。本项目设为 100,可根据实际情况调整:
    • 若训练损失仍在下降,可增加 epochs;
    • 若训练损失下降但测试损失上升(过拟合),可减少 epochs 或加入早停机制。
  • 训练与测试顺序:每轮训练后可加入测试(如在train后调用test),便于监控模型是否过拟合;本项目在所有训练完成后测试,适用于快速验证。

七、单张图片预测:模型的实际应用

训练完成后,需将模型用于实际场景 —— 对单张食物图片进行分类。本项目定义了predict_single_image函数,实现从图像读取到类别输出的完整流程。

7.1 预测函数解析

def predict_single_image(image_path, model, transform, device, label_map):# 1. 读取并预处理图像(与测试集一致)image = Image.open(image_path).convert('RGB')  # 强制RGBimage = transform(image)  # 应用预处理(Resize + ToTensor)# 2. 增加batch维度(模型要求输入为[batch, C, H, W])image = image.unsqueeze(0).to(device)# 3. 模型预测model.eval()  # 开启评估模式with torch.no_grad():pred_logits = model(image)  # 预测分数(logits)pred_label = pred_logits.argmax(1).item()  # 取最高分数类别# 4. 映射为食物名称if pred_label not in label_map:raise KeyError(f"预测标签 {pred_label} 不在标签映射字典中")return label_map[pred_label]
(1)关键步骤解析
  1. 图像预处理一致性:预测时的预处理必须与测试集一致(本项目使用data_transforms['valid']),否则模型输入格式不匹配,预测结果会失真。
  2. 增加 batch 维度:模型训练和测试时输入都是批次数据([batch, C, H, W]),而单张图片是 [C, H, W],需通过unsqueeze(0)在第 0 维(batch 维)增加一个维度,变为 [1, C, H, W]。
  3. 标签映射label_map(如label_to_food)将数字标签(如 0)映射为食物名称(如 “八宝粥”),让预测结果更直观。

7.2 预测实战与结果展示

预测代码如下,用户输入图片路径和真实标签,模型输出预测结果并对比:

# 标签映射字典(与数据集标签对应)
label_to_food = {0: "八宝粥", 1: "巴旦木", 2: "白萝卜", 3: "板栗", 4: "菠萝",5: "草莓", 6: "蛋", 7: "蛋挞", 8: "骨肉相连", 9: "瓜子",10: "哈密瓜", 11: "汉堡", 12: "胡萝卜", 13: "火龙果", 14: "鸡翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯条", 19: "炸鸡"
}# 用户输入
image_path = input("请输入图片路径:")
true_food = input("请输入该图片的真实食物名称:")# 执行预测
predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food
)# 输出结果
print("\n" + "-" * 50)
print(f"预测结果:{predicted_food}")
print(f"真实结果:{true_food}")
print(f"判断:{'预测正确' if predicted_food == true_food else '预测错误'}")
print("-" * 50)
(1)预测示例

假设用户输入:

  • 图片路径:"D:\食物分类\food_dataset\test\八宝粥\img_八宝粥罐_81.jpeg"(一张八宝粥图片)
  • 真实食物名称:八宝粥

模型输出:

请输入图片路径:./test_images/hamburger.jpg
请输入该图片的真实食物名称:汉堡--------------------------------------------------
预测结果:汉堡
真实结果:汉堡
判断:预测正确
--------------------------------------------------
(2)预测错误原因

样本数量不足

训练的轮数过少

http://www.xdnf.cn/news/1447399.html

相关文章:

  • 基础看门狗--idf开发esp32s3
  • PNP具身解读——RSS2025论文加州伯克利RLDG: 通过强化学习实现机器人通才策略提炼。
  • 基于物联网的智慧用电云平台构建与火灾防控应用研究
  • 复杂网络环境不用愁,声网IoT多通道传输实战经验丰富
  • Coze使用教程-插件
  • 袋鼠云产品功能更新报告14期|实时开发,效率再升级!
  • Kafka面试精讲 Day 6:Kafka日志存储结构与索引机制
  • 浏览器插件开发--通过调用本地nmap实现nmap插件扫描
  • python如何解决html格式不规范问题
  • Android使用内存压力测试工具 StressAppTest
  • [嵌入式embed][Qt]Qt5.12+Opencv4.x+Cmake4.x_用Qt编译linux-Opencv库 测试
  • 显存与内存
  • 【甲烷数据】MethaneSAT 卫星遥感数据
  • 使用DCGAN实现动漫图像生成
  • 树莓集团产教融合:数字学院践行职业教育“实体化运营”要求
  • Ubuntu 系统 LVM 逻辑卷扩容教程
  • 中小企业 AI 转型难?成本、技术、人才三重困境下,轻量化解决方案来了
  • 单位冲击响应频谱
  • python-对图片中的头像进行抠图
  • 确定软件需求的方法
  • 小青苔是什么?
  • C语言(长期更新)第13讲:指针详解(三)
  • GTH收发器初始化和复位全解析
  • 面试复习题-kotlin
  • ArcGIS与GISBox对比:中小企业GIS工具的高门槛与零门槛之选
  • Dify部署全攻略:从零开始搭建AI应用开发平台
  • 【高级】系统架构师 | 信息系统战略规划、EAI 与新技术
  • 华为HCIP、HCIE认证:自学与培训班的抉择
  • 《苍穹外卖》开发环境搭建_后端环境搭建【简单易懂注释版】
  • 牛子图论1(二分图+连通性)