当前位置: 首页 > news >正文

DAY 40 训练和测试的规范写法

知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout

作业:仔细学习下测试和训练代码的逻辑,这是基础,这个代码框架后续会一直沿用,后续的重点慢慢就是转向模型定义阶段了。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np# 设备配置
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")# 数据预处理函数 - 处理灰度图像(MNIST)
def get_mnist_loaders(batch_size=64):# 灰度图像归一化transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)return train_loader, test_loader# 数据预处理函数 - 处理彩色图像(CIFAR-10)
def get_cifar10_loaders(batch_size=64):# 彩色图像归一化transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # CIFAR-10的均值和标准差])train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)return train_loader, test_loader# 定义包含展平和dropout的CNN模型
class SimpleCNN(nn.Module):def __init__(self, in_channels=1, num_classes=10):super(SimpleCNN, self).__init__()# 卷积层部分self.conv_layers = nn.Sequential(nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Dropout(0.25),  # 第一个dropout层nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Dropout(0.25)   # 第二个dropout层)# 展平操作后接全连接层self.fc_layers = nn.Sequential(nn.Flatten(),  # 展平操作,保持batch维度不变nn.Linear(64 * 7 * 7, 128),  # 假设输入尺寸为32x32,经过两次池化后为8x8,64通道nn.ReLU(),nn.Dropout(0.5),  # 全连接层后的dropoutnn.Linear(128, num_classes))def forward(self, x):x = self.conv_layers(x)x = self.fc_layers(x)return x# 训练函数 - 规范的训练流程
def train(model, train_loader, criterion, optimizer, epoch, device):model.train()  # 切换到训练模式,启用dropoutrunning_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)# 梯度清零optimizer.zero_grad()# 前向传播output = model(data)loss = criterion(output, target)# 反向传播和优化loss.backward()optimizer.step()# 统计训练信息running_loss += loss.item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if batch_idx % 100 == 0:print(f'Epoch: {epoch} | Batch: {batch_idx} | Loss: {loss.item():.4f} | Acc: {100.*correct/total:.2f}%')epoch_loss = running_loss / len(train_loader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_acc# 测试函数 - 规范的测试流程
def test(model, test_loader, criterion, device):model.eval()  # 切换到测试模式,关闭dropouttest_loss = 0correct = 0total = 0with torch.no_grad():  # 不计算梯度,节省内存和计算资源for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)# 累加损失test_loss += criterion(output, target).item()# 计算准确率_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()test_loss /= len(test_loader)test_acc = 100. * correct / totalprint(f'Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.2f}%')return test_loss, test_acc# 主函数 - 整合整个流程
def main(use_color=False):# 选择数据集if use_color:print("Using CIFAR-10 (color images) dataset...")train_loader, test_loader = get_cifar10_loaders(batch_size=128)in_channels = 3  # 彩色图像3通道else:print("Using MNIST (grayscale images) dataset...")train_loader, test_loader = get_mnist_loaders(batch_size=128)in_channels = 1  # 灰度图像1通道# 初始化模型model = SimpleCNN(in_channels=in_channels, num_classes=10).to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练和测试循环num_epochs = 5train_losses, train_accs, test_losses, test_accs = [], [], [], []for epoch in range(1, num_epochs + 1):print(f"\nEpoch {epoch}/{num_epochs}")train_loss, train_acc = train(model, train_loader, criterion, optimizer, epoch, device)test_loss, test_acc = test(model, test_loader, criterion, device)train_losses.append(train_loss)train_accs.append(train_acc)test_losses.append(test_loss)test_accs.append(test_acc)# 绘制训练过程plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_losses, label='Train Loss')plt.plot(test_losses, label='Test Loss')plt.title('Loss Curve')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(train_accs, label='Train Accuracy')plt.plot(test_accs, label='Test Accuracy')plt.title('Accuracy Curve')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.tight_layout()plt.show()if __name__ == "__main__":# 运行灰度图像版本(MNIST)main(use_color=False)# 取消注释以下行运行彩色图像版本(CIFAR-10)# main(use_color=True)

http://www.xdnf.cn/news/722899.html

相关文章:

  • <PLC><socket><西门子>基于西门子S7-1200PLC,实现手机与PLC通讯(通过websocket转接)
  • 每日温度(力扣-739)
  • 零知开源——STM32F407VET6驱动Flappy Bird游戏教程
  • 深兰科技董事长陈海波受邀出席2025苏商高质量发展(常州)峰会,共话AI驱动产业升级
  • LVS-DR 负载均衡集群
  • Spring Boot 整合 Spring Security
  • 后端项目中静态文案国际化语言包构建选型
  • 华为云Flexus+DeepSeek征文 | 基于Dify和DeepSeek-R1开发企业级AI Agent全流程指南
  • 什么是Docker容器?
  • 【Linux 基础知识系列】第三篇-Linux 基本命令
  • 探索C++模板STL
  • Vert.x学习笔记-EventLoop工作原理
  • AI赋能开源:如何借助MCP快速解锁开源项目并提交你的首个PR
  • 机房网络设备操作安全管理制度
  • 历年中国农业大学计算机保研上机真题
  • 深入详解DICOMweb:WADO与STOW-RS的技术解析与实现
  • 如何安全地清洁 Windows10/11PC上的SSD驱动器
  • 系统思考:经营决策沙盘
  • 知识图谱增强的大型语言模型编辑
  • 【Linux】vim编辑器
  • 服务器如何配置防火墙管理端口访问?
  • Ubuntu20.04服务器开启路由转发让局域网内其他电脑通过该服务器连接外网
  • 【仿muduo库实现并发服务器】实现时间轮定时器
  • 戴尔AI服务器订单激增至121亿美元,但传统业务承压
  • 24核32G,千兆共享:裸金属服务器的技术原理与优势
  • VRRP 原理与配置:让你的网络永不掉线!
  • Dify运行本地和在线模型
  • Oracle数据库性能优化的最佳实践
  • 【appium】环境安装部署问题记录
  • 达梦数据库——修改、删除物化视图