"""
DAY 40 超大力王爱学Python
知识点回顾:
- 彩色和灰度图片测试和训练的规范写法:封装在函数中
- 展平操作:除第一个维度batchsize外全部展平
- dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout
作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from typing import Tuple, Callable, Optional# 1. 数据预处理与加载函数
def get_data_loaders(
    dataset_name: str = 'MNIST',  # one of: 'MNIST' or 'CIFAR10'
    batch_size: int = 64,
    data_dir: str = './data',
    num_workers: int = 2,
) -> Tuple[DataLoader, DataLoader]:
    """Build train and test DataLoaders for a grayscale (MNIST) or color (CIFAR10) dataset.

    Args:
        dataset_name: 'MNIST' (1-channel, 28x28) or 'CIFAR10' (3-channel, 32x32).
        batch_size: samples per batch for both loaders.
        data_dir: root directory where raw dataset files are stored/downloaded.
        num_workers: subprocess count for data loading.

    Returns:
        (train_loader, test_loader) tuple; the train loader shuffles, the test
        loader does not.

    Raises:
        ValueError: if dataset_name is not a supported dataset.
    """
    if dataset_name == 'MNIST':
        # Grayscale: ToTensor scales to [0, 1], then normalize with the
        # dataset-specific mean/std.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean / std
        ])
        dataset_cls = datasets.MNIST
    elif dataset_name == 'CIFAR10':
        # Color: per-channel (0.5, 0.5) normalization maps [0, 1] -> [-1, 1].
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        dataset_cls = datasets.CIFAR10
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    train_dataset = dataset_cls(data_dir, train=True, download=True, transform=transform)
    # FIX: request download=True for the test split as well, so the function is
    # robust even when the raw files are not already present on disk (the
    # original only downloaded via the train split).
    test_dataset = dataset_cls(data_dir, train=False, download=True, transform=transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader


# 2. Generic model definition (supports grayscale and color images)
class Flatten(nn.Module):"""自定义展平层,保留batch维度"""def forward(self, x):return x.view(x.size(0), -1) # 保留batch维度,展平其余维度class ImageClassifier(nn.Module):"""通用图像分类器,支持灰度和彩色图像"""def __init__(self,input_channels: int = 1, # MNIST:1, CIFAR10:3input_size: int = 28, # MNIST:28, CIFAR10:32hidden_size: int = 128,num_classes: int = 10,dropout_rate: float = 0.5):super().__init__()self.model = nn.Sequential(Flatten(), # 展平除batch外的所有维度nn.Linear(input_channels * input_size * input_size, hidden_size),nn.ReLU(),nn.Dropout(dropout_rate), # 训练时随机丢弃神经元nn.Linear(hidden_size, num_classes))def forward(self, x):return self.model(x)# 3. 训练函数
def train(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
    log_interval: int = 100,
) -> None:
    """Run a single training epoch over train_loader.

    Args:
        model: network to optimize (switched to train mode here).
        train_loader: yields (data, target) batches.
        criterion: loss function applied to (output, target).
        optimizer: optimizer stepping model.parameters().
        device: device the batches are moved to.
        epoch: current epoch number, used only for log messages.
        log_interval: print progress every this many batches.
    """
    model.train()  # enable training-mode behavior (dropout active, etc.)
    total_loss = 0.0
    num_batches = len(train_loader)
    dataset_size = len(train_loader.dataset)

    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass (gradients cleared first).
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        # Periodic progress logging (includes batch 0).
        if step % log_interval == 0:
            print(f'Epoch: {epoch} [{step * len(inputs)}/{dataset_size} '
                  f'({100. * step / num_batches:.0f}%)]\tLoss: {batch_loss:.6f}')

    # Mean of per-batch losses for this epoch.
    print(f'Epoch {epoch} average loss: {total_loss / num_batches:.4f}')


# 4. Test function
def test(model: nn.Module,test_loader: DataLoader,criterion: nn.Module,device: torch.device
) -> Tuple[float, float]:"""评估模型在测试集上的性能"""model.eval() # 启用评估模式(关闭dropout等)test_loss = 0correct = 0with torch.no_grad(): # 不计算梯度,节省内存和计算资源for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item() # 累加批次损失pred = output.argmax(dim=1, keepdim=True) # 获取最大概率的类别correct += pred.eq(target.view_as(pred)).sum().item() # 统计正确预测数# 计算平均损失和准确率test_loss /= len(test_loader)accuracy = 100. * correct / len(test_loader.dataset)print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} 'f'({accuracy:.2f}%)\n')return test_loss, accuracy# 5. 主函数:训练和测试流程
def main(
    dataset_name: str = 'MNIST',
    batch_size: int = 64,
    epochs: int = 5,
    lr: float = 0.001,
    dropout_rate: float = 0.5,
    use_cuda: bool = True,
) -> None:
    """Wire together data loading, model construction, training and testing.

    Trains for `epochs` epochs on the chosen dataset, evaluating after each
    epoch, then saves the final weights to '<dataset_name>_mlp_model.pth'.

    Raises:
        ValueError: if dataset_name is not a supported dataset.
    """
    # Pick GPU when requested and available, else fall back to CPU.
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
    print(f"Using device: {device}")

    train_loader, test_loader = get_data_loaders(
        dataset_name=dataset_name,
        batch_size=batch_size,
    )

    # Per-dataset input geometry: (channels, side length, class count).
    dataset_specs = {
        'MNIST': (1, 28, 10),
        'CIFAR10': (3, 32, 10),
    }
    if dataset_name not in dataset_specs:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    input_channels, input_size, num_classes = dataset_specs[dataset_name]

    model = ImageClassifier(
        input_channels=input_channels,
        input_size=input_size,
        hidden_size=128,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Alternate one training epoch with one full evaluation pass.
    for epoch in range(1, epochs + 1):
        train(model, train_loader, criterion, optimizer, device, epoch)
        test(model, test_loader, criterion, device)

    # Persist only the learned parameters (state_dict), not the module object.
    torch.save(model.state_dict(), f"{dataset_name}_mlp_model.pth")
    print(f"Model saved as: {dataset_name}_mlp_model.pth")


if __name__ == "__main__":
    # Train the MNIST model.
    main(dataset_name='MNIST', batch_size=64, epochs=5)
    # To train the CIFAR10 model instead, uncomment the line below:
    # main(dataset_name='CIFAR10', batch_size=64, epochs=10)
# @浙大疏锦行