当前位置: 首页 > news >正文

python训练营打卡第39天

图像数据与显存

知识点回顾

  1. 图像数据的格式:灰度和彩色数据
  2. 模型的定义
  3. 显存占用的4种地方
    1. 模型参数+梯度参数
    2. 优化器参数
    3. 数据批量所占显存
    4. 神经元输出中间状态
  4. batchisize和训练的关系
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from torchsummary import summary# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 定义可视化函数
def imshow(img, title=None, is_color=True):"""可视化图像,支持灰度图和彩色图"""if is_color:img = img / 2 + 0.5  # 反标准化处理,将图像范围从[-1,1]转回[0,1]npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 调整维度顺序:(通道,高,宽) → (高,宽,通道)else:img = img * 0.3081 + 0.1307  # 反标准化MNIST图像npimg = img.numpy()plt.imshow(npimg[0], cmap='gray')  # 显示灰度图像if title:plt.title(title)plt.axis('off')plt.show()# 定义MNIST数据集的MLP模型
class MNIST_MLP(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.layer1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.layer2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = self.layer1(x)x = self.relu(x)x = self.layer2(x)return x# 定义CIFAR-10数据集的MLP模型
class CIFAR10_MLP(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(3072, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x# 训练函数
def train_model(model, train_loader, criterion, optimizer, device, epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for i, (images, labels) in enumerate(train_loader):images, labels = images.to(device), labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}')return model# 评估函数
def evaluate_model(model, test_loader, device):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f'Accuracy: {accuracy:.2f}%')return accuracy# 主函数:处理MNIST数据集
def mnist_demo():print("="*50)print("MNIST数据集演示")print("="*50)# MNIST数据预处理mnist_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])# 加载MNIST数据集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)test_dataset = datasets.MNIST(root='./data', train=False, transform=mnist_transform)# 创建数据加载器train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 可视化样本sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()image, label = train_dataset[sample_idx]imshow(image, f'MNIST样本: {label}', is_color=False)# 初始化模型device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = MNIST_MLP().to(device)# 打印模型摘要print("\nMNIST模型结构信息:")summary(model, input_size=(1, 28, 28))# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型print("\n开始训练MNIST模型...")model = train_model(model, train_loader, criterion, optimizer, device, epochs=5)# 评估模型print("\n评估MNIST模型...")evaluate_model(model, test_loader, device)# 主函数:处理CIFAR-10数据集
def cifar10_demo():print("="*50)print("CIFAR-10数据集演示")print("="*50)# CIFAR-10数据预处理cifar_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载CIFAR-10数据集train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar_transform)test_dataset = datasets.CIFAR10(root='./data', train=False, transform=cifar_transform)# CIFAR-10的类别classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 创建数据加载器train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 可视化样本sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()image, label = train_dataset[sample_idx]imshow(image, f'CIFAR-10样本: {classes[label]}', is_color=True)# 初始化模型device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = CIFAR10_MLP().to(device)# 打印模型摘要print("\nCIFAR-10模型结构信息:")summary(model, input_size=(3, 32, 32))# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型print("\n开始训练CIFAR-10模型...")model = train_model(model, train_loader, criterion, optimizer, device, epochs=5)# 评估模型print("\n评估CIFAR-10模型...")evaluate_model(model, test_loader, device)if __name__ == "__main__":# 运行MNIST演示mnist_demo()# 运行CIFAR-10演示cifar10_demo()    

@浙大疏锦行

http://www.xdnf.cn/news/738865.html

相关文章:

  • OAuth详解和应用
  • AI互联网辅助工具
  • 8位单通道数据保存为JPG
  • 【有向图 拓扑排序 】P8405 [COCI 2021/2022 #6] Naboj|普及+
  • 为什么arc中,(cons ‘a (cons 1 (cons “foo“ ‘(b) ))) 是(a 1 “foo“ b)
  • 使用函数证明给定的三个数是否能构成三角形
  • 偏序集、哈斯图、Dilworth
  • 如何做好一份技术文档
  • java25
  • python笔面试题汇总
  • 如何选择合适的培养基过滤器
  • python打卡训练营打卡记录day40
  • 案例分享--血管支架的径向力分布评估--DIC数字图像相关技术用于生物医学-高置信度DIC测量
  • 拉深工艺模块——回转体拉深件毛坯尺寸的确定(一)
  • 初探Linux内核:解锁Linux操作系统的基本核心的奥秘(二)
  • Prevent this information from being displayed to the user 修复方案
  • 涨薪技术|0到1学会性能测试第91课-性能测试过程执行、分析、诊断、调节
  • ASR、TTS与语音克隆技术简介
  • QML 滑动与翻转效果(Flickable与Flipable)
  • 小狼毫输入法雾凇拼音输入方案辅码由默认的部件拆字/拼音输入方案修改为五笔画方案
  • 书送希望 智启未来 —— 赛力斯超级工厂携手渝北和合家园小学校开展公益赠书活动
  • JavaSwing之--JPasswordField
  • 系统设计——状态机模型设计经验
  • Linux ClearOS yum无法使用解决备忘
  • Qt Dial(旋钮)
  • 智慧赋能充电桩管理:我国新能源充电桩建设现状与突破路径
  • 【Doris基础】Apache Doris业务场景全解析:从实时数仓到OLAP分析的完美选择
  • Linux操作系统 使用共享内存实现进程通信和同步
  • 近期手上的一个基于Function Grap(类AWS的Lambda)小项目的改造引发的思考
  • URAT接收实验日志,传输无效