Python Training Camp, Day 39 Check-in
Image Data and GPU Memory
Knowledge Review
- Image data formats: grayscale and color (RGB) data
- Defining the model
- The 4 consumers of GPU memory (a rough estimate is sketched after this list)
  - Model parameters + their gradients
  - Optimizer states
  - Memory occupied by the data batch
  - Intermediate activations (neuron outputs)
- The relationship between batch size and training (measured in the sketch at the end of the section)
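To make the four memory consumers concrete, here is a minimal back-of-the-envelope sketch for the MNIST MLP defined below (784 → 128 → 10). It assumes float32 storage (4 bytes per value), the Adam optimizer (which keeps two state tensors per parameter), and the batch size of 64 used by the DataLoader in the demo. The activation term only counts the layer outputs, so the totals are rough lower bounds, not what nvidia-smi will report.

```python
# Rough estimate of the four GPU-memory consumers for the MNIST MLP (784 -> 128 -> 10).
# Assumptions: float32 (4 bytes per value), Adam optimizer (exp_avg + exp_avg_sq per parameter),
# batch_size = 64 as in the demo's DataLoader. Activations only count the layer outputs.

BYTES_PER_FLOAT32 = 4
batch_size = 64

# 1) Model parameters: two Linear layers, weights + biases
n_params = (784 * 128 + 128) + (128 * 10 + 10)   # = 101,770 values

# 2) Gradients: one value per parameter during backprop
n_grads = n_params

# 3) Optimizer states: Adam stores exp_avg and exp_avg_sq for every parameter
n_optim = 2 * n_params

# 4a) Input batch: batch_size x 1 x 28 x 28 images
n_batch = batch_size * 1 * 28 * 28

# 4b) Intermediate activations kept around for the backward pass (fc1, ReLU, fc2 outputs)
n_acts = batch_size * (128 + 128 + 10)

total_bytes = (n_params + n_grads + n_optim + n_batch + n_acts) * BYTES_PER_FLOAT32
print(f"parameters      : {n_params * BYTES_PER_FLOAT32 / 1024:.1f} KiB")
print(f"gradients       : {n_grads * BYTES_PER_FLOAT32 / 1024:.1f} KiB")
print(f"optimizer states: {n_optim * BYTES_PER_FLOAT32 / 1024:.1f} KiB")
print(f"batch + acts    : {(n_batch + n_acts) * BYTES_PER_FLOAT32 / 1024:.1f} KiB")
print(f"total (estimate): {total_bytes / 1024 / 1024:.2f} MiB")
```

For a model this small, parameters, gradients and optimizer states together stay under 2 MiB, so almost all of the scaling comes from the batch and activation terms, which grow linearly with batch size.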
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from torchsummary import summary

# Set random seed for reproducibility
torch.manual_seed(42)

# Visualization helper
def imshow(img, title=None, is_color=True):
    """Visualize an image; supports both grayscale and color images."""
    if is_color:
        img = img / 2 + 0.5  # Un-normalize: map the image from [-1, 1] back to [0, 1]
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))  # Reorder dims: (C, H, W) -> (H, W, C)
    else:
        img = img * 0.3081 + 0.1307  # Un-normalize the MNIST image
        npimg = img.numpy()
        plt.imshow(npimg[0], cmap='gray')  # Display the grayscale image
    if title:
        plt.title(title)
    plt.axis('off')
    plt.show()

# MLP model for the MNIST dataset
class MNIST_MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layer1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return x

# MLP model for the CIFAR-10 dataset
class CIFAR10_MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(3072, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Training function
def train_model(model, train_loader, criterion, optimizer, device, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}')
    return model

# Evaluation function
def evaluate_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy

# Main function: MNIST dataset
def mnist_demo():
    print("="*50)
    print("MNIST dataset demo")
    print("="*50)

    # MNIST preprocessing
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load the MNIST dataset
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=mnist_transform)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # Visualize a sample
    sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
    image, label = train_dataset[sample_idx]
    imshow(image, f'MNIST sample: {label}', is_color=False)

    # Initialize the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MNIST_MLP().to(device)

    # Print the model summary
    print("\nMNIST model structure:")
    summary(model, input_size=(1, 28, 28))

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    print("\nTraining the MNIST model...")
    model = train_model(model, train_loader, criterion, optimizer, device, epochs=5)

    # Evaluate the model
    print("\nEvaluating the MNIST model...")
    evaluate_model(model, test_loader, device)

# Main function: CIFAR-10 dataset
def cifar10_demo():
    print("="*50)
    print("CIFAR-10 dataset demo")
    print("="*50)

    # CIFAR-10 preprocessing
    cifar_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load the CIFAR-10 dataset
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar_transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=cifar_transform)

    # CIFAR-10 class names
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # Visualize a sample
    sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
    image, label = train_dataset[sample_idx]
    imshow(image, f'CIFAR-10 sample: {classes[label]}', is_color=True)

    # Initialize the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CIFAR10_MLP().to(device)

    # Print the model summary
    print("\nCIFAR-10 model structure:")
    summary(model, input_size=(3, 32, 32))

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    print("\nTraining the CIFAR-10 model...")
    model = train_model(model, train_loader, criterion, optimizer, device, epochs=5)

    # Evaluate the model
    print("\nEvaluating the CIFAR-10 model...")
    evaluate_model(model, test_loader, device)

if __name__ == "__main__":
    # Run the MNIST demo
    mnist_demo()
    # Run the CIFAR-10 demo
    cifar10_demo()
```
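The last bullet in the knowledge list (batch size vs. training) can be checked empirically. The sketch below is not part of the original script: it builds a throwaway MLP with the same shape as CIFAR10_MLP, runs one training step at several batch sizes, and reports the peak GPU memory via torch.cuda.max_memory_allocated. The helper name peak_memory_for_batch and the batch sizes are illustrative; the measurement only runs when CUDA is available.

```python
import torch
import torch.nn as nn

def peak_memory_for_batch(batch_size):
    """Run one training step on fake CIFAR-10-shaped data and return peak GPU memory in MiB.
    Hypothetical helper, not part of the original script."""
    device = torch.device("cuda")
    torch.cuda.reset_peak_memory_stats(device)
    # Same shape as CIFAR10_MLP above: 3*32*32 = 3072 -> 128 -> 10
    model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 128), nn.ReLU(), nn.Linear(128, 10)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    images = torch.randn(batch_size, 3, 32, 32, device=device)   # fake image batch
    labels = torch.randint(0, 10, (batch_size,), device=device)  # fake labels
    loss = criterion(model(images), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return torch.cuda.max_memory_allocated(device) / 1024 ** 2

if torch.cuda.is_available():
    for bs in (16, 64, 256, 1024):
        print(f"batch_size={bs:>4}: peak GPU memory ~ {peak_memory_for_batch(bs):.1f} MiB")
```

Parameters, gradients and optimizer states are the same at every batch size; only the batch and activation terms grow, which is why the peak rises roughly linearly with batch_size.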
@浙大疏锦行