从 0 到 1 实现 PyTorch 食物图像分类:核心知识点与完整实
食物图像分类是计算机视觉的经典任务之一,其核心是让机器 “看懂” 图像中的食物类别。随着深度学习的发展,卷积神经网络(CNN)凭借强大的特征提取能力,成为图像分类的主流方案。本文将基于 PyTorch 框架,从代码实战出发,拆解食物图像分类项目中的核心知识点,包括环境搭建、数据预处理、数据集构建、CNN 模型设计、模型训练与测试、单图预测等,带大家从零搭建一个能识别 20 类食物的分类系统。
# 导入必要的库
import torch # PyTorch核心库,用于构建和训练神经网络
from torch import nn # 神经网络模块,包含各种层和损失函数
from torch.utils.data import Dataset, DataLoader # 数据集和数据加载器,用于数据处理
import numpy as np # 数值计算库,可用于数据预处理等
from PIL import Image # 图像处理库,用于读取和处理图像
from torchvision import transforms # 图像转换工具,用于数据增强和预处理
import os # 操作系统接口,用于文件路径处理等# 定义数据转换策略:训练集使用数据增强,验证集/测试集保持一致的基础转换
data_transforms = {'train': # 训练集转换(包含数据增强,增加样本多样性)transforms.Compose([transforms.Resize([300, 300]), # 先将图像调整为300x300transforms.RandomRotation(45), # 随机旋转(-45~45度),增强旋转不变性transforms.CenterCrop(256), # 中心裁剪到256x256,去除旋转后的黑边transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转transforms.RandomVerticalFlip(p=0.5), # 50%概率垂直翻转transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), # 随机调整亮度、对比度、饱和度和色调transforms.RandomGrayscale(p=0.1), # 10%概率转为灰度图transforms.ToTensor(), # 转为Tensor格式([C, H, W]),并将像素值归一化到[0,1]transforms.Normalize( # 使用ImageNet的均值和标准差进行标准化[0.485, 0.456, 0.406], # 均值(RGB三个通道)[0.229, 0.224, 0.225] # 标准差(RGB三个通道))]),'valid': # 验证集/测试集转换(无增强,保持数据一致性)transforms.Compose([transforms.Resize([256, 256]), # 调整为256x256,与训练集裁剪后尺寸一致transforms.ToTensor(), # 转为Tensor]),
}# -------------------------- 2. 自定义数据集类 --------------------------
class food_dataset(Dataset):"""自定义食物图像数据集类,继承自PyTorch的Dataset用于加载图像路径和对应标签,并进行预处理"""def __init__(self, file_path, transform=None):"""初始化数据集:param file_path: 存储图像路径和标签的文本文件路径:param transform: 图像转换函数(预处理/数据增强)"""self.file_path = file_path # 文本文件路径self.transform = transform # 转换函数self.imgs = [] # 存储所有图像路径self.labels = [] # 存储对应标签# 读取文件列表(每行格式:图片路径 数字标签)with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip() # 去除首尾空格和换行符if not line: # 跳过空行continue# 按空格分割路径和标签(假设格式严格,无多余空格)img_path, label = line.split(' ')self.imgs.append(img_path)self.labels.append(label)def __len__(self):"""返回数据集样本数量"""return len(self.imgs)def __getitem__(self, index):"""根据索引获取单个样本(图像和标签):param index: 样本索引:return: 处理后的图像张量和标签张量"""# 读取图片并强制转为RGB(避免灰度图导致的通道数不匹配问题)try:image = Image.open(self.imgs[index]).convert('RGB') # 确保3通道输入except Exception as e:# 捕获读取错误,便于调试raise ValueError(f"读取图片 {self.imgs[index]} 失败:{e}")# 应用转换(预处理/数据增强)if self.transform:image = self.transform(image)# 处理标签:转为整数类型的张量(PyTorch分类任务要求标签为long类型)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label# 加载数据集
# 注意:需确保train.txt和test.txt文件存在,每行格式为「图片路径 数字标签」
try:# 加载训练集(使用训练集转换)training_data = food_dataset(file_path='./train.txt', transform=data_transforms['train'])# 加载测试集(使用验证集转换)test_data = food_dataset(file_path='./test.txt', transform=data_transforms['valid'])
except FileNotFoundError:# 捕获文件不存在错误,提示用户raise FileNotFoundError("请确保 train.txt 和 test.txt 文件在当前目录下")# 创建数据加载器(批量加载数据,支持打乱和多进程)
train_dataloader = DataLoader(training_data,batch_size=8, # 批大小:每次加载8张图片shuffle=True # 训练时打乱数据顺序,增强训练效果
)
test_dataloader = DataLoader(test_data,batch_size=8, # 测试时也用相同批大小shuffle=True # 测试时打乱不影响结果,主要便于观察不同样本
)# 设备配置:优先使用GPU(cuda),其次是Apple M系列芯片(mps),最后是CPU
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device') # 打印使用的设备# 定义CNN模型(卷积神经网络)
class CNN(nn.Module):"""自定义卷积神经网络模型,用于食物图像分类包含4个卷积块和1个全连接输出层"""def __init__(self):super().__init__() # 调用父类nn.Module的初始化方法# 第一个卷积块:1次卷积 + ReLU激活 + 最大池化self.conv1 = nn.Sequential(# 卷积层:输入3通道(RGB),输出16通道,卷积核5x5,步长1,填充2(保持尺寸)nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(), # 激活函数,引入非线性nn.MaxPool2d(kernel_size=2), # 最大池化:尺寸减半(256→128))# 第二个卷积块:2次卷积 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), # 输入16通道,输出32通道nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2), # 输入32通道,输出32通道nn.ReLU(),nn.MaxPool2d(2), # 尺寸减半(128→64))# 第三个卷积块:2次卷积 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2), # 输入32通道,输出64通道nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2), # 输入64通道,输出128通道nn.ReLU(),nn.MaxPool2d(2), # 尺寸减半(64→32))# 第四个卷积块:1次卷积 + ReLU(无池化,保持尺寸)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2), # 输入128通道,输出128通道nn.ReLU(), # 输出尺寸:32×32,通道数128)# 全连接输出层:将特征映射到20个类别(食物种类)# 输入尺寸计算:128通道 × 32高 × 32宽(经多次池化后的特征图尺寸)self.out = nn.Linear(128 * 32 * 32, 20) # 20类:需与标签数量一致def forward(self, x):"""前向传播:定义数据在网络中的流动路径:param x: 输入张量,形状为[batch_size, 3, 256, 256]:return: 输出张量,形状为[batch_size, 20](各类别的预测分数)"""x = self.conv1(x) # 经第一个卷积块处理x = self.conv2(x) # 经第二个卷积块处理x = self.conv3(x) # 经第三个卷积块处理x = self.conv4(x) # 经第四个卷积块处理x = x.view(x.size(0), -1) # 展平特征图:[batch_size, 128*32*32]output = self.out(x) # 经全连接层输出预测结果return output# -------------------------- 训练与测试函数 --------------------------
def train(dataloader, model, loss_fn, optimizer):"""训练模型的函数:param dataloader: 训练数据集加载器:param model: 待训练的模型:param loss_fn: 损失函数(用于计算预测误差):param optimizer: 优化器(用于更新模型参数)"""model.train() # 开启训练模式(启用Dropout、BatchNorm等训练特定行为)batch_size_num = 1 # 记录当前批次编号for X, y in dataloader:# 将数据移动到指定设备(GPU/CPU)X, y = X.to(device), y.to(device)# 前向传播:计算模型预测结果pred = model(X)# 计算损失(预测值与真实标签的差距)loss = loss_fn(pred, y)# 反向传播与参数更新optimizer.zero_grad() # 清空上一轮的梯度(避免梯度累积)loss.backward() # 反向传播计算梯度optimizer.step() # 根据梯度更新模型参数# 打印损失(每2个batch打印一次,便于监控训练过程)loss_val = loss.item() # 获取损失的标量值if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f} [batch: {batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):model.eval() # 开启评估模式(关闭Dropout、固定BatchNorm参数等)size = len(dataloader.dataset) # 测试集总样本数num_batches = len(dataloader) # 测试集批次数test_loss, correct = 0, 0 # 总损失和正确预测数# 关闭梯度计算(测试时不需要更新参数,节省计算资源)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device) # 数据移至设备pred = model(X) # 预测test_loss += loss_fn(pred, y).item() # 累加损失# 统计正确预测数:取预测概率最大的类别与真实标签比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batches # 平均损失correct /= size # 准确率print(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")# -------------------------- 单张图片预测函数 --------------------------
def predict_single_image(image_path, model, transform, device, label_map):"""对单张图片进行预测:param image_path: 图片路径:param model: 训练好的模型:param transform: 图像预处理函数(与测试集一致):param device: 计算设备:param label_map: 标签映射字典(数字标签→食物名称):return: 预测的食物名称"""# 读取并预处理图片(与测试集预处理一致)image = Image.open(image_path).convert('RGB') # 确保3通道image = transform(image) # 应用预处理(Resize和ToTensor)# 增加batch维度(模型要求输入格式为[batch, C, H, W],这里batch=1)image = image.unsqueeze(0).to(device)# 模型预测model.eval() # 开启评估模式with torch.no_grad(): # 关闭梯度计算pred_logits = model(image) # 得到预测分数(logits)# 取概率最大的类别标签(argmax(1)按行取最大值索引)pred_label = pred_logits.argmax(1).item()# 映射为食物名称if pred_label not in label_map:raise KeyError(f"预测标签 {pred_label} 不在标签映射字典中")return label_map[pred_label]# -------------------------- 主程序 --------------------------
if __name__ == "__main__":# 初始化模型、损失函数、优化器model = CNN().to(device) # 创建模型并移至设备loss_fn = nn.CrossEntropyLoss() # 多分类问题常用交叉熵损失# Adam优化器:自适应学习率,训练效果较好,学习率0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型(100轮)epochs = 100for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n----------------------------")train(train_dataloader, model, loss_fn, optimizer)print("Training Done!")# 测试模型在测试集上的性能test(test_dataloader, model, loss_fn)# 定义标签映射字典:数字标签→食物名称# 需与数据集的标签完全对应(顺序和数量一致)label_to_food = {0: "八宝粥", 1: "巴旦木", 2: "白萝卜", 3: "板栗", 4: "菠萝",5: "草莓", 6: "蛋", 7: "蛋挞", 8: "骨肉相连", 9: "瓜子",10: "哈密瓜", 11: "汉堡", 12: "胡萝卜", 13: "火龙果", 14: "鸡翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯条", 19: "炸鸡"}# 输入图片路径并预测image_path = input("请输入图片路径:") # 用户输入待预测图片路径true_food = input("请输入该图片的真实食物名称:") # 用户输入真实标签(用于对比)# 执行预测并输出结果predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food)# 输出对比结果print("\n" + "-" * 50)print(f"预测结果:{predicted_food}")print(f"真实结果:{true_food}")print(f"判断:{'预测正确' if predicted_food == true_food else '预测错误'}")print("-" * 50)
二、数据预处理:让数据 “适配” 模型
在深度学习中,数据预处理的质量直接影响模型性能。原始图像可能存在尺寸不一、像素值范围差异大、样本数量不足等问题,需通过预处理将其转化为模型可接受的格式,并通过数据增强提升模型泛化能力。
本项目的预处理逻辑集中在data_transforms
字典中,分 “训练集” 和 “验证集 / 测试集” 两种策略,我们逐一拆解其设计思路。
2.1 为什么要区分训练集与验证集预处理?
- 训练集:需要通过 “数据增强” 增加样本多样性,避免模型过拟合(即模型只记住训练样本,对新样本识别能力差)。
- 验证集 / 测试集:需保持数据的 “真实性”,仅进行基础预处理(如 Resize、ToTensor),确保评估结果能反映模型的实际泛化能力。
2.2 训练集数据增强:每一步的作用与原理
训练集的预处理链为:Resize → RandomRotation → CenterCrop → RandomHorizontalFlip → RandomVerticalFlip → ColorJitter → RandomGrayscale → ToTensor → Normalize
,我们逐个解析:
(1)Resize ([300, 300]):统一初始尺寸
将所有图像调整为 300×300 像素。为什么不直接调整为最终的 256×256?因为后续会进行旋转和裁剪,预留一定尺寸可避免旋转后出现黑边。
(2)RandomRotation (45):随机旋转
随机将图像旋转 - 45°~45°。食物在拍摄时可能有不同角度(如躺着的汉堡、竖放的胡萝卜),旋转增强能让模型对角度不敏感,提升鲁棒性。
(3)CenterCrop (256):中心裁剪
将旋转后的图像从中心裁剪为 256×256。旋转会导致图像边缘出现黑边,裁剪可去除黑边,同时将图像尺寸统一为模型输入尺寸(256×256)。
(4)RandomHorizontalFlip (p=0.5) & RandomVerticalFlip (p=0.5):随机翻转
- 水平翻转(50% 概率):模拟 “左右镜像” 的食物(如翻转后的草莓外观不变)。
- 垂直翻转(50% 概率):模拟 “上下颠倒” 的场景(如掉落的薯条)。
翻转操作不改变食物的核心特征,但能增加样本多样性,且计算成本低。
(5)ColorJitter (0.1, 0.1, 0.1, 0.1):随机颜色抖动
调整图像的亮度、对比度、饱和度、色调,各参数的取值范围为 0~1(0 表示不调整,1 表示最大调整幅度)。
食物图像的颜色易受光照影响(如白天和夜晚拍摄的青菜颜色不同),颜色抖动能让模型对光照变化不敏感。
(6)RandomGrayscale (p=0.1):随机灰度化
10% 概率将彩色图像转为灰度图。虽然食物的颜色是重要特征,但灰度化能迫使模型关注食物的形状、纹理等更本质的特征,避免过度依赖颜色信息(如红色的草莓和红色的圣女果,需通过形状区分)。
(7)ToTensor ():转为 Tensor 格式
将 PIL 图像(H×W×C,像素值 0~255)转为 PyTorch Tensor(C×H×W,像素值归一化到 0~1)。
- 维度转换:模型要求输入为 “通道优先”(C×H×W),而 PIL 图像是 “高度优先”(H×W×C),需通过 ToTensor 调整。
- 归一化:将像素值从 0~255 缩放到 0~1,避免大数值导致模型梯度爆炸。
(8)Normalize ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):标准化
使用 ImageNet 数据集的均值和标准差对 Tensor 进行标准化,公式为:
标准化后像素值 = (原始像素值 - 均值) / 标准差
为什么用 ImageNet 的参数?因为本项目后续可扩展为迁移学习(使用预训练模型),而预训练模型是在 ImageNet 上训练的,使用相同的标准化参数能让模型更快收敛。
transforms.Compose([transforms.Resize([300, 300]), # 先将图像调整为300x300transforms.RandomRotation(45), # 随机旋转(-45~45度),增强旋转不变性transforms.CenterCrop(256), # 中心裁剪到256x256,去除旋转后的黑边transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转transforms.RandomVerticalFlip(p=0.5), # 50%概率垂直翻转transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), # 随机调整亮度、对比度、饱和度和色调transforms.RandomGrayscale(p=0.1), # 10%概率转为灰度图transforms.ToTensor(), # 转为Tensor格式([C, H, W]),并将像素值归一化到[0,1]transforms.Normalize( # 使用ImageNet的均值和标准差进行标准化[0.485, 0.456, 0.406], # 均值(RGB三个通道)[0.229, 0.224, 0.225] # 标准差(RGB三个通道))
2.3 验证集预处理
验证集的预处理链为:Resize([256, 256]) → ToTensor()
,仅保留基础操作:
- Resize ([256, 256]):直接将图像调整为 256×256,无需旋转(避免引入非真实样本)。
- ToTensor ():与训练集一致,确保数据格式统一。
(注:代码中验证集未做 Normalize,实际项目中建议与训练集保持一致,此处可根据需求调整)
三、自定义 Dataset:PyTorch 数据加载的核心
PyTorch 通过Dataset
和DataLoader
实现数据加载,其中Dataset
负责 “定义数据来源和格式”,DataLoader
负责 “批量加载和并行处理”。本项目自定义了food_dataset
类,用于加载食物图像和对应标签,我们详细解析其实现逻辑。
3.1 Dataset 的核心作用
Dataset
是一个抽象类,要求子类必须实现三个方法:
__init__
:初始化数据集(读取文件列表、加载预处理函数)。__len__
:返回数据集的总样本数。__getitem__
:根据索引返回单个样本(图像 + 标签)。
这三个方法确保了 PyTorch 能高效地迭代访问数据。
3.2 food_dataset 类逐方法解析
(1)init:初始化数据列表
def __init__(self, file_path, transform=None):self.file_path = file_path # 存储图像路径和标签的txt文件路径self.transform = transform # 预处理函数self.imgs = [] # 存储所有图像路径self.labels = [] # 存储对应标签# 读取txt文件,解析图像路径和标签with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip() # 去除首尾空格和换行符if not line: # 跳过空行(避免解析错误)continueimg_path, label = line.split(' ') # 按空格分割路径和标签self.imgs.append(img_path)self.labels.append(label)
- txt 文件格式要求:每行需包含 “图像路径” 和 “数字标签”,用空格分隔
其中 “0” 对应 “八宝粥”,“1” 对应 “巴旦木”,需与后续label_to_food
字典一致。
(2)len:返回样本总数
def __len__(self):return len(self.imgs)
简单直接,返回self.imgs
的长度),DataLoader
会通过该方法确定迭代次数。
(3)getitem:返回单个样本
def __getitem__(self, index):# 读取图像并强制转为RGB(避免灰度图通道数问题)try:image = Image.open(self.imgs[index]).convert('RGB')except Exception as e:raise ValueError(f"读取图片 {self.imgs[index]} 失败:{e}")# 应用预处理if self.transform:image = self.transform(image)# 处理标签:转为int64类型Tensor(PyTorch分类任务要求)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label
这是Dataset
的核心方法,需重点关注三个细节:
- 强制 RGB 格式:
convert('RGB')
确保所有图像都是 3 通道(避免部分灰度图是 1 通道,导致模型输入维度不匹配)。 - 异常处理:
try-except
捕获图像读取错误(如路径错误、图像损坏),并明确提示错误位置,便于调试。 - 标签类型:将标签转为
torch.int64
(即 LongTensor),因为 PyTorch 的CrossEntropyLoss
要求标签为 Long 类型。
3.3 如何准备自己的数据集?
- 收集图像:每个食物类别收集至少 100 张图像(样本越多,模型性能越好),建议按类别分文件夹存储
- 生成 txt 文件:编写脚本遍历图像文件夹,生成
train.txt
和test.txt
import os# 数据集根目录 train_root = "./dataset/train" test_root = "./dataset/test" # 标签映射(与后续一致) label_to_food = {0: "八宝粥", 1: "巴旦木", ..., 19: "炸鸡"} # 反向映射:食物名称→数字标签 food_to_label = {v: k for k, v in label_to_food.items()}# 生成train.txt with open("train.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(train_root):food_dir = os.path.join(train_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")# 生成test.txt(逻辑同上) with open("test.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(test_root):food_dir = os.path.join(test_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")
- 检查路径:确保 txt 文件中的图像路径与实际文件路径一致
四、DataLoader:批量加载与并行处理
Dataset
定义了数据的 “来源”,而DataLoader
则负责将数据 “批量加载” 到模型中,并支持并行处理,提升数据加载速度。
4.1 DataLoader 的核心参数解析
本项目的DataLoader
初始化代码如下:
train_dataloader = DataLoader(training_data, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=8, shuffle=True)
核心参数含义:
batch_size
:每次加载的样本数量(批大小),需根据 GPU 显存调整shuffle
:是否打乱数据顺序。训练集设为True
,验证集可设为False
,本项目测试集设为True
是为了观察不同样本的预测效果。
4.2 DataLoader 与 Dataset 的协作流程
DataLoader
的工作流程可概括为:
- 调用
Dataset.__len__()
获取总样本数,计算总批次数(总样本数 //batch_size)。 - 若
shuffle=True
,则在每个 epoch(训练轮次)开始前打乱样本索引。 - 对每个批次,根据索引调用
Dataset.__getitem__()
获取单个样本,组装成一个批次的 Tensor(形状为 [batch_size, C, H, W])。 - 将批次数据移动到指定设备(CUDA/MPS/CPU),供模型训练或测试。
4.3 数据加载到设备的逻辑
X, y = X.to(device), y.to(device)
其中device
是通过以下代码确定的:
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
- 为什么要移动设备? 模型和数据必须在同一设备上才能进行计算(如模型在 CUDA 上,数据也需在 CUDA 上),否则会报错。
- 设备优先级:优先使用 CUDA(NVIDIA GPU),其次是 MPS(Apple M 系列),最后是 CPU
五、CNN 模型构建:从卷积到全连接的特征提取
卷积神经网络(CNN)是图像分类的核心,其通过 “卷积层提取局部特征→池化层降维→全连接层分类” 的流程,实现对图像的识别。本项目的 CNN 模型包含 4 个卷积块和 1 个全连接层,我们逐一解析其设计思路和尺寸计算。
5.1 CNN 的核心组件与作用
在解析代码前,先回顾 CNN 的三个核心组件:
- 卷积层(Conv2d):通过卷积核滑动提取图像的局部特征(如边缘、纹理、形状),输出 “特征图”(Feature Map)。
- 激活函数(ReLU):引入非线性,让模型能拟合复杂的特征关系(避免线性模型的表达能力不足)。
- 池化层(MaxPool2d):对特征图进行下采样,降低维度和计算量,同时增强模型对特征位置的鲁棒性。
5.2 模型代码逐块解析
模型定义代码如下,我们按 “卷积块 1→卷积块 2→卷积块 3→卷积块 4→全连接层” 的顺序解析:
class CNN(nn.Module):def __init__(self):super().__init__()# 卷积块1:1次卷积 + ReLU + 最大池化self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2), # 池化后尺寸:256→128)# 卷积块2:2次卷积 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2), # 尺寸:128→64)# 卷积块3:2次卷积 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2), # 尺寸:64→32)# 卷积块4:1次卷积 + ReLU(无池化)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2),nn.ReLU(), # 输出尺寸:32×32,通道数128)# 全连接层:映射到20类self.out = nn.Linear(128 * 32 * 32, 20)
池化层参数计算:尺寸减半
MaxPool2d(kernel_size=2)
表示池化核大小为 2×2,步长默认等于核大小(即 2),因此输出尺寸为输入尺寸的 1/2:
conv1
池化前尺寸:256×256 → 池化后:128×128conv2
池化前尺寸:128×128 → 池化后:64×64conv3
池化前尺寸:64×64 → 池化后:32×32
各卷积块的输出特征
- conv1:输出特征图形状为 [batch_size, 16, 128, 128],提取的是图像的低级特征(如边缘、颜色块)。
- conv2:输出形状为 [batch_size, 32, 64, 64],通过 2 次卷积提取更复杂的特征(如食物的局部轮廓)。
- conv3:输出形状为 [batch_size, 128, 32, 32],通道数增加到 128,特征更抽象(如食物的结构特征)。
- conv4:输出形状为 [batch_size, 128, 32, 32],无池化,进一步细化特征(避免池化导致的特征损失)。
(4)全连接层:从特征到分类
全连接层self.out
的输入维度是128×32×32
,这是由conv4
的输出特征图形状决定的:
conv4
输出:[batch_size, 128, 32, 32] → 展平后为 [batch_size, 128×32×32](展平操作在forward
中通过x.view(x.size(0), -1)
实现)。- 输出维度:20,对应 20 种食物类别,每个维度输出该类别的 “预测分数”(后续通过
argmax
取分数最高的类别作为预测结果)。
5.3 forward 方法:定义数据流动路径
def forward(self, x):x = self.conv1(x) # 经卷积块1处理x = self.conv2(x) # 经卷积块2处理x = self.conv3(x) # 经卷积块3处理x = self.conv4(x) # 经卷积块4处理x = x.view(x.size(0), -1) # 展平:[batch_size, 128*32*32]output = self.out(x) # 全连接层输出return output
- 展平操作:
x.view(x.size(0), -1)
将 4 维特征图(batch, C, H, W)转为 2 维张量(batch, C×H×W),因为全连接层仅接受 2 维输入。 - 数据流动:输入图像([batch, 3, 256, 256])→ 卷积块 1→2→3→4 → 展平 → 全连接层 → 输出([batch, 20])。
5.4 模型初始化与设备移动
模型初始化代码如下:
model = CNN().to(device)
CNN()
创建模型实例,to(device)
将模型参数移动到指定设备(CUDA/MPS/CPU),确保模型和数据在同一设备上计算。
六、模型训练与测试:从损失下降到性能评估
模型构建完成后,需通过训练让模型 “学习” 食物特征,再通过测试评估模型的泛化能力。本项目定义了train
和test
两个函数,分别实现训练和测试逻辑。
6.1 训练函数:让模型 “学习”
训练函数的核心是 “前向传播计算损失→反向传播更新参数”,代码如下:
def train(dataloader, model, loss_fn, optimizer):model.train() # 开启训练模式(启用Dropout、BatchNorm训练行为)batch_size_num = 1 # 批次编号,用于打印损失for X, y in dataloader:# 数据移动到设备X, y = X.to(device), y.to(device)# 1. 前向传播:计算预测结果pred = model(X)# 2. 计算损失:预测值与真实标签的差距loss = loss_fn(pred, y)# 3. 反向传播与参数更新optimizer.zero_grad() # 清空上一轮梯度(避免累积)loss.backward() # 反向传播计算梯度optimizer.step() # 根据梯度更新模型参数# 打印损失(每2个批次打印一次)loss_val = loss.item()if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f} [batch: {batch_size_num}]")batch_size_num += 1
(1)关键步骤解析
model.train():开启训练模式,对含有 Dropout、BatchNorm 的模型至关重要:
- Dropout:训练时随机 “关闭” 部分神经元,防止过拟合;测试时不关闭。
- BatchNorm:训练时使用批次的均值和方差归一化;测试时使用训练阶段累积的均值和方差。
前向传播(Forward Pass):
pred = model(X)
:将批次数据输入模型,得到预测结果([batch, 20])。loss = loss_fn(pred, y)
:计算损失,本项目使用CrossEntropyLoss
(多分类任务的常用损失函数)。
反向传播(Backward Pass)与参数更新:
optimizer.zero_grad()
:清空梯度。若不清空,梯度会累积到上一轮,导致参数更新错误。loss.backward()
:根据损失计算各参数的梯度(optimizer.step()
:根据梯度更新模型参数
(2)损失函数与优化器选择
本项目使用的损失函数和优化器如下:
loss_fn = nn.CrossEntropyLoss() # 多分类交叉熵损失
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam优化器
- CrossEntropyLoss:适用于多分类任务
- Adam 优化器:自适应学习率优化器,收敛速度快,对学习率不敏感,是深度学习中最常用的优化器之一。
lr=0.001
是常用的初始学习率,可根据训练情况调整
6.2 测试函数:评估模型泛化能力
测试函数的核心是 “计算模型在测试集上的准确率和平均损失”,代码如下:
def test(dataloader, model, loss_fn):model.eval() # 开启评估模式(关闭Dropout、固定BatchNorm)size = len(dataloader.dataset) # 测试集总样本数num_batches = len(dataloader) # 测试集批次数test_loss, correct = 0, 0 # 总损失和正确预测数# 关闭梯度计算(节省资源,避免参数更新)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累加损失test_loss += loss_fn(pred, y).item()# 累加正确预测数correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batchescorrect /= sizeprint(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")
(1)关键步骤解析
- model.eval():开启评估模式,与
model.train()
对应,确保模型在测试时的行为与训练时一致(如关闭 Dropout)。 - torch.no_grad():上下文管理器,关闭梯度计算。测试时无需更新参数,关闭梯度可大幅减少内存占用和计算时间。
- 准确率计算:
pred.argmax(1)
:对每个样本,取预测分数最高的类别(维度 1 是类别维度,[batch, 20]→[batch, 1])。(pred.argmax(1) == y)
:比较预测类别与真实标签,得到布尔张量(True = 正确,False = 错误)。type(torch.float).sum().item()
:将布尔张量转为 float(True=1,False=0),求和得到正确预测数,再转为 Python 标量。
- 结果解读:
- Accuracy:准确率(正确预测数 / 总样本数),反映模型的整体识别能力,越高越好。
- Avg Loss:平均损失,反映模型预测值与真实标签的平均差距,越低越好。
6.3 训练流程与轮次设置
训练流程代码如下:
# 训练轮次(epochs)
epochs = 100
for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n")train(train_dataloader, model, loss_fn, optimizer)
print("Training Done!")# 测试模型
test(test_dataloader, model, loss_fn)
- epochs(训练轮次):表示模型将遍历整个训练集的次数。本项目设为 100,可根据实际情况调整:
- 若训练损失仍在下降,可增加 epochs;
- 若训练损失下降但测试损失上升(过拟合),可减少 epochs 或加入早停机制。
- 训练与测试顺序:每轮训练后可加入测试(如在
train
后调用test
),便于监控模型是否过拟合;本项目在所有训练完成后测试,适用于快速验证。
七、单张图片预测:模型的实际应用
训练完成后,需将模型用于实际场景 —— 对单张食物图片进行分类。本项目定义了predict_single_image
函数,实现从图像读取到类别输出的完整流程。
7.1 预测函数解析
def predict_single_image(image_path, model, transform, device, label_map):# 1. 读取并预处理图像(与测试集一致)image = Image.open(image_path).convert('RGB') # 强制RGBimage = transform(image) # 应用预处理(Resize + ToTensor)# 2. 增加batch维度(模型要求输入为[batch, C, H, W])image = image.unsqueeze(0).to(device)# 3. 模型预测model.eval() # 开启评估模式with torch.no_grad():pred_logits = model(image) # 预测分数(logits)pred_label = pred_logits.argmax(1).item() # 取最高分数类别# 4. 映射为食物名称if pred_label not in label_map:raise KeyError(f"预测标签 {pred_label} 不在标签映射字典中")return label_map[pred_label]
(1)关键步骤解析
- 图像预处理一致性:预测时的预处理必须与测试集一致(本项目使用
data_transforms['valid']
),否则模型输入格式不匹配,预测结果会失真。 - 增加 batch 维度:模型训练和测试时输入都是批次数据([batch, C, H, W]),而单张图片是 [C, H, W],需通过
unsqueeze(0)
在第 0 维(batch 维)增加一个维度,变为 [1, C, H, W]。 - 标签映射:
label_map
(如label_to_food
)将数字标签(如 0)映射为食物名称(如 “八宝粥”),让预测结果更直观。
7.2 预测实战与结果展示
预测代码如下,用户输入图片路径和真实标签,模型输出预测结果并对比:
# 标签映射字典(与数据集标签对应)
label_to_food = {0: "八宝粥", 1: "巴旦木", 2: "白萝卜", 3: "板栗", 4: "菠萝",5: "草莓", 6: "蛋", 7: "蛋挞", 8: "骨肉相连", 9: "瓜子",10: "哈密瓜", 11: "汉堡", 12: "胡萝卜", 13: "火龙果", 14: "鸡翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯条", 19: "炸鸡"
}# 用户输入
image_path = input("请输入图片路径:")
true_food = input("请输入该图片的真实食物名称:")# 执行预测
predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food
)# 输出结果
print("\n" + "-" * 50)
print(f"预测结果:{predicted_food}")
print(f"真实结果:{true_food}")
print(f"判断:{'预测正确' if predicted_food == true_food else '预测错误'}")
print("-" * 50)
(1)预测示例
假设用户输入:
- 图片路径:"D:\食物分类\food_dataset\test\八宝粥\img_八宝粥罐_81.jpeg"(一张八宝粥图片)
- 真实食物名称:八宝粥
模型输出:
请输入图片路径:./test_images/hamburger.jpg
请输入该图片的真实食物名称:汉堡--------------------------------------------------
预测结果:汉堡
真实结果:汉堡
判断:预测正确
--------------------------------------------------
(2)预测错误原因
样本数量不足
训练的轮数过少