Day 39: Image Data and GPU Memory
Review of key points
- Image data formats: grayscale and color images
- Defining the model
- The four places GPU memory is consumed (estimated concretely in the sketch below):
  - model parameters + their gradients
  - optimizer state
  - the data batch itself
  - intermediate activations (layer outputs)
- The relationship between batch size and training
Homework: there is little code today; understanding the content is enough.
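To make the four memory consumers concrete before the lesson code, here is a minimal back-of-the-envelope sketch of my own (not part of the lesson). It assumes float32 tensors (4 bytes per element) and the Adam optimizer, which stores two state tensors (exp_avg and exp_avg_sq) per parameter; the activation estimate is deliberately rough.

import torch
import torch.nn as nn

BYTES_FP32 = 4

# Image format: a grayscale batch has shape (N, 1, H, W); a color batch is (N, 3, H, W)
color_batch = torch.zeros(64, 3, 32, 32)  # CIFAR-10-sized color images

model = nn.Sequential(nn.Flatten(),
                      nn.Linear(3 * 32 * 32, 128), nn.ReLU(),
                      nn.Linear(128, 10))

n_params = sum(p.numel() for p in model.parameters())
param_mb = n_params * BYTES_FP32 / 1024**2

# 1) parameters + gradients: every parameter gets a same-sized gradient, so roughly 2x
params_grads_mb = 2 * param_mb
# 2) optimizer state: Adam keeps exp_avg and exp_avg_sq per parameter, another 2x
optimizer_mb = 2 * param_mb
# 3) the data batch itself
batch_mb = color_batch.numel() * BYTES_FP32 / 1024**2
# 4) intermediate activations kept for backward (rough: the two Linear outputs per sample)
activations_mb = 64 * (128 + 10) * BYTES_FP32 / 1024**2

print(f"params+grads {params_grads_mb:.2f} MB | optimizer {optimizer_mb:.2f} MB | "
      f"batch {batch_mb:.2f} MB | activations {activations_mb:.2f} MB")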
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Check whether a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Define DataLoaders with different batch sizes
batch_sizes = [16, 32, 64, 128]
train_loaders = {bs: DataLoader(train_dataset, batch_size=bs, shuffle=True) for bs in batch_sizes}
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Define a simple CNN model
class CIFARCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CIFARCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),   # 32x32x16
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                  # 16x16x16
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # 16x16x32
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),                  # 8x8x32
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # 8x8x64
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)                   # 4x4x64
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
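As a quick sanity check on this definition (my addition, not in the original notes), we can count the parameters and verify that the feature extractor really produces the 4x4x64 = 1024-value map that nn.Linear(64 * 4 * 4, 128) expects:

m = CIFARCNN()
print(f"parameters: {sum(p.numel() for p in m.parameters()):,}")  # ~156k, about 0.6 MB in float32

dummy = torch.zeros(2, 3, 32, 32)   # a fake 2-image CIFAR-10 batch
print(m.features(dummy).shape)      # expected: torch.Size([2, 64, 4, 4])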
# GPU memory analysis helper
def print_memory_usage(prefix=""):
    if torch.cuda.is_available():
        print(f"{prefix}GPU memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB "
              f"| reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")

# Training function
def train_model(batch_size, epochs=5):
    model = CIFARCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train_loader = train_loaders[batch_size]
    train_losses = []
    test_accuracies = []

    print(f"\n=== Training starts: Batch Size = {batch_size} ===")
    print_memory_usage("After init: ")

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        # Simulate gradient accumulation (useful when batch_size is small)
        gradient_accumulation_steps = max(1, 32 // batch_size)
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss = loss / gradient_accumulation_steps  # scale the loss
            # Backward pass
            loss.backward()
            # Gradient accumulation: update once every gradient_accumulation_steps batches
            if (i + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            running_loss += loss.item() * gradient_accumulation_steps
            # Only print memory usage for the first few batches
            if i < 2:
                print_memory_usage(f"Epoch {epoch+1}, Batch {i+1}: ")
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)

        # Evaluate on the test set
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_accuracy = 100 * correct / total
        test_accuracies.append(test_accuracy)

        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Test Acc: {test_accuracy:.2f}%")

    return train_losses, test_accuracies
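Why the loss is divided by gradient_accumulation_steps: backward() adds gradients into .grad until zero_grad() clears them, so k scaled micro-batches reproduce one step on a batch k times larger. A small self-contained check of that equivalence (my own sketch, independent of the CIFAR code):

lin = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# One backward pass on the full batch of 8
lin.zero_grad()
nn.functional.mse_loss(lin(x), y).backward()
full_grad = lin.weight.grad.clone()

# Two accumulated backward passes on half-batches, each loss scaled by 1/2
lin.zero_grad()
for xb, yb in ((x[:4], y[:4]), (x[4:], y[4:])):
    (nn.functional.mse_loss(lin(xb), yb) / 2).backward()

print(torch.allclose(full_grad, lin.weight.grad, atol=1e-6))  # True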
# Train a model for each batch size
results = {}
for bs in batch_sizes:
    # Clear the GPU cache before each run
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    losses, accuracies = train_model(bs, epochs=3)
    results[bs] = (losses, accuracies)

# Visualize the results
plt.figure(figsize=(12, 5))

# Plot the loss curves
plt.subplot(1, 2, 1)
for bs in batch_sizes:
    plt.plot(results[bs][0], marker='o', label=f'Batch Size={bs}')
plt.xlabel('Epoch')
plt.ylabel('Training Loss')
plt.title('Training Loss for Different Batch Sizes')
plt.legend()
plt.grid(True)

# Plot the accuracy curves
plt.subplot(1, 2, 2)
for bs in batch_sizes:
    plt.plot(results[bs][1], marker='o', label=f'Batch Size={bs}')
plt.xlabel('Epoch')
plt.ylabel('Test Accuracy (%)')
plt.title('Test Accuracy for Different Batch Sizes')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
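memory_allocated() reports the current footprint, but when comparing batch sizes the peak (high-water mark) is the more telling number. A sketch of my own (not in the lesson code) using PyTorch's built-in peak tracking, assuming a CUDA device is present:

if torch.cuda.is_available():
    for bs in batch_sizes:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        m = CIFARCNN().to(device)
        inputs = torch.randn(bs, 3, 32, 32, device=device)
        labels = torch.randint(0, 10, (bs,), device=device)
        # One forward/backward pass; activations dominate as bs grows
        loss = nn.CrossEntropyLoss()(m(inputs), labels)
        loss.backward()
        peak = torch.cuda.max_memory_allocated() / 1024**2
        print(f"batch_size={bs}: peak GPU memory {peak:.2f} MB")
        del m, inputs, labels, loss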
# Memory optimization demo: gradient checkpointing
from torch.utils.checkpoint import checkpoint

class CheckpointCNN(CIFARCNN):
    def forward(self, x):
        def custom_forward(*inputs):
            x = inputs[0]
            for module in self.features:
                x = module(x)
            return x
        # Recompute the feature activations during backward instead of storing them
        # (use_reentrant=False avoids a deprecation warning on recent PyTorch)
        x = checkpoint(custom_forward, x, use_reentrant=False)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# Use the gradient-checkpointing model
print("\n=== 使用梯度检查点优化显存 ===")
checkpoint_model = CheckpointCNN().to(device)
print_memory_usage("梯度检查点模型初始化后: ")