A Simple ViT Implementation in PyTorch and MindSpore
This post builds a simple ViT for image classification, implemented once in PyTorch and once in MindSpore.
PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# MLP Block
# fc -> gelu -> dropout -> fc -> dropout
# input: emb_dim, mlp_ratio, dropout
class MLP(nn.Module):
    def __init__(self, emb_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        hidden_dim = int(emb_dim * mlp_ratio)
        self.fc1 = nn.Linear(emb_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)  # a fresh dropout mask is sampled on every call

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

# Multi-Head Self-Attention
# input: emb_dim, num_heads, dropout
# attention scores -> softmax -> attention weights
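For reference (my addition, not in the original post), the class below computes standard scaled dot-product attention per head, where $d_{\text{head}}$ is the per-head dimension and the code's `scale` is $1/\sqrt{d_{\text{head}}}$:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}}\right)V$$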
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_dim, num_heads=8, dropout=0.1):
        super().__init__()
        assert emb_dim % num_heads == 0
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(emb_dim, emb_dim * 3)
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, num_heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads=8, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.attn = MultiHeadAttention(emb_dim, num_heads, dropout)
        self.mlp = MLP(emb_dim, mlp_ratio, dropout)

    def forward(self, x):
        # pre-norm residual connections
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, emb_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)  # (B, emb_dim, H//patch_size, W//patch_size)
        x = x.flatten(2).transpose(-1, -2)  # (B, num_patches, emb_dim)
        return x


class ViT(nn.Module):
    def __init__(self, num_classes=10, img_size=224, patch_size=16, in_channels=3,
                 emb_dim=768, num_heads=8, depth=8, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.patch_emb = PatchEmbedding(img_size, patch_size, in_channels, emb_dim)
        num_patches = self.patch_emb.num_patches
        self.pos_emb = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(emb_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(emb_dim)
        self.proj = nn.Linear(emb_dim, num_classes)
        self.init_weights()

    def init_weights(self):
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_emb(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_emb
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        cls_token_output = x[:, 0]  # classify on the cls token
        return self.proj(cls_token_output)
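A quick shape check is not part of the original post, but a sketch like the following (my addition, using the defaults above) can confirm the wiring before training:

# Sanity-check sketch (my addition): push a dummy batch through the model.
if __name__ == "__main__":
    model = ViT(num_classes=10, img_size=224, patch_size=16,
                emb_dim=768, num_heads=8, depth=8)
    dummy = torch.randn(2, 3, 224, 224)  # (B, C, H, W)
    logits = model(dummy)
    print(logits.shape)  # expect torch.Size([2, 10])

The training script below is a separate file that imports the model definitions above as pytorch_build.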
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import matplotlib.pyplot as plt
import pytorch_build  # the model code above, saved as pytorch_build.py

# 6. Create a small ViT suited to CIFAR-10
def create_tiny_vit():
    return pytorch_build.ViT(
        img_size=32,
        patch_size=4,
        in_channels=3,
        num_classes=10,
        emb_dim=192,
        depth=6,
        num_heads=3,
        mlp_ratio=4.0,
        dropout=0.1,
    )

# 7. Data loading
def get_cifar10_loaders(batch_size=64):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader

# 8. Training function (one epoch)
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 200 == 0:
            print(f'  Batch {batch_idx}, Loss: {loss.item():.4f}, Acc: {100.*correct/total:.2f}%')
    return running_loss / len(train_loader), 100. * correct / total

# 9. Evaluation function
def test_epoch(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return test_loss / len(test_loader), 100. * correct / total

# 10. Main training routine
def train_vit_on_cifar10(epochs=30, batch_size=64, lr=3e-4):
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Data
    train_loader, test_loader = get_cifar10_loaders(batch_size)
    # Model
    model = create_tiny_vit().to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {total_params:,}")
    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    # Bookkeeping
    train_losses, train_accs = [], []
    test_losses, test_accs = [], []
    best_acc = 0
    print(f"Starting training for {epochs} epochs...")
    for epoch in range(epochs):
        start_time = time.time()
        # Train and evaluate
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = test_epoch(model, test_loader, criterion, device)
        # Step the learning-rate schedule
        scheduler.step()
        # Record
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        # Save the best model
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'best_vit_cifar10.pth')
        # Report
        epoch_time = time.time() - start_time
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch {epoch+1}/{epochs} ({epoch_time:.1f}s) - LR: {current_lr:.6f}')
        print(f'  Train: Loss {train_loss:.4f}, Acc {train_acc:.2f}%')
        print(f'  Test:  Loss {test_loss:.4f}, Acc {test_acc:.2f}% (Best: {best_acc:.2f}%)')
        print('-' * 60)
    return {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'test_losses': test_losses,
        'test_accs': test_accs,
        'best_acc': best_acc,
    }

# 11. Visualization
def plot_results(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    epochs = range(1, len(history['train_losses']) + 1)
    # Loss
    ax1.plot(epochs, history['train_losses'], 'b-', label='Train Loss')
    ax1.plot(epochs, history['test_losses'], 'r-', label='Test Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Test Loss')
    ax1.legend()
    ax1.grid(True)
    # Accuracy
    ax2.plot(epochs, history['train_accs'], 'b-', label='Train Acc')
    ax2.plot(epochs, history['test_accs'], 'r-', label='Test Acc')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Training and Test Accuracy')
    ax2.legend()
    ax2.grid(True)
    plt.tight_layout()
    plt.show()

# 12. Entry point
if __name__ == "__main__":
    print("=== PyTorch ViT on CIFAR-10 ===\n")
    # Train
    history = train_vit_on_cifar10(epochs=20, batch_size=64, lr=3e-4)
    # Report the result
    print(f"\nTraining finished! Best test accuracy: {history['best_acc']:.2f}%")
    # Plot the training curves
    plot_results(history)
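After training, the saved weights can be reloaded for inference. This is a minimal sketch (my addition, not in the original), assuming the best_vit_cifar10.pth checkpoint produced above:

import torch
import pytorch_build

# Rebuild the same tiny architecture, then load the saved weights.
model = pytorch_build.ViT(img_size=32, patch_size=4, in_channels=3, num_classes=10,
                          emb_dim=192, depth=6, num_heads=3, mlp_ratio=4.0, dropout=0.1)
model.load_state_dict(torch.load('best_vit_cifar10.pth', map_location='cpu'))
model.eval()
with torch.no_grad():
    image = torch.randn(1, 3, 32, 32)  # stand-in for a normalized CIFAR-10 image
    pred = model(image).argmax(dim=1)
    print(pred)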
MindSpore
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer, TruncatedNormal, Constant
import numpy as np

# Configure the MindSpore context
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU")  # change to "GPU" if available

# 1. MLP Block
class MLP(nn.Cell):
    def __init__(self, emb_dim, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(emb_dim * mlp_ratio)
        self.fc1 = nn.Dense(emb_dim, hidden_dim)
        self.fc2 = nn.Dense(hidden_dim, emb_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.gelu = nn.GELU()

    def construct(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

# 2. Multi-Head Self-Attention
class MultiHeadAttention(nn.Cell):
    def __init__(self, emb_dim, num_heads=8, dropout=0.1):
        super().__init__()
        assert emb_dim % num_heads == 0
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Dense(emb_dim, emb_dim * 3)
        self.proj = nn.Dense(emb_dim, emb_dim)
        self.dropout = nn.Dropout(p=dropout)
        # MindSpore operators
        self.softmax = nn.Softmax(axis=-1)
        self.transpose = ops.Transpose()
        self.reshape = ops.Reshape()
        self.bmm = ops.BatchMatMul()

    def construct(self, x):
        B, N, C = x.shape
        # Produce Q, K, V
        qkv = self.qkv(x)
        qkv = self.reshape(qkv, (B, N, 3, self.num_heads, self.head_dim))
        qkv = self.transpose(qkv, (2, 0, 3, 1, 4))  # (3, B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Attention scores
        k_t = self.transpose(k, (0, 1, 3, 2))  # (B, num_heads, head_dim, N)
        attn = self.bmm(q, k_t) * self.scale  # (B, num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        # Apply attention to the values
        x = self.bmm(attn, v)  # (B, num_heads, N, head_dim)
        x = self.transpose(x, (0, 2, 1, 3))  # (B, N, num_heads, head_dim)
        x = self.reshape(x, (B, N, C))
        x = self.proj(x)
        x = self.dropout(x)
        return x

# 3. Transformer Block
class TransformerBlock(nn.Cell):
    def __init__(self, emb_dim, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm((emb_dim,))
        self.attn = MultiHeadAttention(emb_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm((emb_dim,))
        self.mlp = MLP(emb_dim, mlp_ratio, dropout)

    def construct(self, x):
        # Pre-norm residual connections
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# 4. Patch Embedding
class PatchEmbedding(nn.Cell):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, emb_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # has_bias=True matches the bias-enabled Conv2d of the PyTorch version
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size,
                              stride=patch_size, has_bias=True)
        self.reshape = ops.Reshape()
        self.transpose = ops.Transpose()

    def construct(self, x):
        x = self.proj(x)  # (B, emb_dim, H//patch_size, W//patch_size)
        B, C, H, W = x.shape
        x = self.reshape(x, (B, C, H * W))  # flatten the spatial grid
        x = self.transpose(x, (0, 2, 1))  # (B, num_patches, emb_dim)
        return x

# 5. Vision Transformer
class ViT(nn.Cell):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=10,
                 emb_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, emb_dim)
        num_patches = self.patch_embed.num_patches
        # Learnable parameters
        self.pos_emb = Parameter(ops.zeros((1, num_patches + 1, emb_dim), ms.float32), name='pos_emb')
        self.cls_token = Parameter(ops.zeros((1, 1, emb_dim), ms.float32), name='cls_token')
        self.dropout = nn.Dropout(p=dropout)
        # Transformer blocks
        self.blocks = nn.CellList([
            TransformerBlock(emb_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm((emb_dim,))
        self.head = nn.Dense(emb_dim, num_classes)
        # MindSpore operators
        self.tile = ops.Tile()
        self.concat = ops.Concat(axis=1)
        self.init_weights()

    def init_weights(self):
        # Initialize the positional embedding and cls token
        self.pos_emb.set_data(initializer(TruncatedNormal(0.02), self.pos_emb.shape, self.pos_emb.dtype))
        self.cls_token.set_data(initializer(TruncatedNormal(0.02), self.cls_token.shape, self.cls_token.dtype))
        # Initialize the remaining layers
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(initializer(Constant(0.0), cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, nn.LayerNorm):
                cell.gamma.set_data(initializer(Constant(1.0), cell.gamma.shape, cell.gamma.dtype))
                cell.beta.set_data(initializer(Constant(0.0), cell.beta.shape, cell.beta.dtype))

    def construct(self, x):
        B = x.shape[0]
        # Patch embedding
        x = self.patch_embed(x)
        # Prepend the cls token
        cls_tokens = self.tile(self.cls_token, (B, 1, 1))
        x = self.concat([cls_tokens, x])
        # Add the positional embedding
        x = x + self.pos_emb
        x = self.dropout(x)
        # Transformer blocks
        for block in self.blocks:
            x = block(x)
        # Classify on the cls token
        x = self.norm(x)
        cls_token_final = x[:, 0]
        return self.head(cls_token_final)
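As on the PyTorch side, a small forward-pass check (my addition, assuming PYNATIVE mode on CPU as configured above) helps verify the cell before training:

# Sanity-check sketch (my addition): run a dummy batch through the cell.
if __name__ == "__main__":
    import numpy as np
    model = ViT(img_size=32, patch_size=4, num_classes=10,
                emb_dim=192, depth=6, num_heads=3)
    dummy = Tensor(np.random.randn(2, 3, 32, 32).astype(np.float32))
    print(model(dummy).shape)  # expect (2, 10)

As with the PyTorch version, the training script below is a separate file that imports the model code above as mindspore_build.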
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
import matplotlib.pyplot as plt
import time
import mindspore_build  # the model code above, saved as mindspore_build.py

# 6. Create the small ViT
def create_tiny_vit():
    return mindspore_build.ViT(
        img_size=32,
        patch_size=4,
        in_channels=3,
        num_classes=10,
        emb_dim=192,
        depth=6,
        num_heads=3,
        mlp_ratio=4.0,
        dropout=0.1,
    )

# 7. CIFAR-10 preprocessing
def create_cifar10_dataset(data_path='./cifar-10-batches-bin', batch_size=64, training=True):
    """Create the CIFAR-10 dataset pipeline."""
    if training:
        # Training-time transforms
        transform_list = [
            vision.RandomCrop(32, padding=4),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.ToTensor(),
            vision.Normalize(mean=[0.4914, 0.4822, 0.4465],
                             std=[0.2023, 0.1994, 0.2010],
                             is_hwc=False),
        ]
    else:
        # Test-time transforms
        transform_list = [
            vision.ToTensor(),
            vision.Normalize(mean=[0.4914, 0.4822, 0.4465],
                             std=[0.2023, 0.1994, 0.2010],
                             is_hwc=False),
        ]
    # Build the dataset
    dataset = ds.Cifar10Dataset(
        dataset_dir=data_path,
        usage='train' if training else 'test',
        shuffle=training,
    )
    # Apply the transforms
    dataset = dataset.map(operations=transform_list, input_columns="image")
    dataset = dataset.map(operations=transforms.TypeCast(ms.int32), input_columns="label")
    # Batch
    dataset = dataset.batch(batch_size, drop_remainder=training)
    return dataset
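Note that ds.Cifar10Dataset reads the binary-format files from disk and does not download them for you. A helper along these lines (my addition, fetching the standard CIFAR-10 binary archive) can prepare the directory first:

# Download helper sketch (my addition): fetch cifar-10-batches-bin/ if missing.
import os
import tarfile
import urllib.request

def download_cifar10_binary(root='.'):
    target = os.path.join(root, 'cifar-10-batches-bin')
    if not os.path.isdir(target):
        url = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
        archive = os.path.join(root, 'cifar-10-binary.tar.gz')
        urllib.request.urlretrieve(url, archive)
        with tarfile.open(archive, 'r:gz') as tar:
            tar.extractall(root)  # extracts cifar-10-batches-bin/
    return target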
# 8. Train one epoch
def train_epoch(model, dataset, optimizer, loss_fn):
    """Train for one epoch."""
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    def train_step(data, label):
        (loss, logits), grads = grad_fn(data, label)
        loss = ops.depend(loss, optimizer(grads))
        return loss, logits

    model.set_train()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    batch_count = 0
    for batch_idx, (data, labels) in enumerate(dataset.create_tuple_iterator()):
        loss, logits = train_step(data, labels)
        total_loss += loss.asnumpy()
        pred_labels = ops.Argmax(axis=-1)(logits)  # predicted class per sample
        correct = ops.Equal()(pred_labels, labels).sum()
        total_correct += correct.asnumpy()
        total_samples += labels.shape[0]
        batch_count += 1
        if batch_idx % 200 == 0:
            acc = 100. * total_correct / total_samples
            print(f'  Batch {batch_idx}, Loss: {loss.asnumpy():.4f}, Acc: {acc:.2f}%')
    avg_loss = total_loss / batch_count
    avg_acc = 100. * total_correct / total_samples
    return avg_loss, avg_acc

# 9. Evaluation function
def test_epoch(model, dataset, loss_fn):
    """Evaluate the model."""
    model.set_train(False)
    total_loss = 0
    total_correct = 0
    total_samples = 0
    batch_count = 0
    for data, labels in dataset.create_tuple_iterator():
        logits = model(data)
        loss = loss_fn(logits, labels)
        total_loss += loss.asnumpy()
        pred_labels = ops.Argmax(axis=-1)(logits)
        correct = ops.Equal()(pred_labels, labels).sum()
        total_correct += correct.asnumpy()
        total_samples += labels.shape[0]
        batch_count += 1
    avg_loss = total_loss / batch_count
    avg_acc = 100. * total_correct / total_samples
    return avg_loss, avg_acc

# 10. Main training routine
def train_vit_on_cifar10(epochs=20, batch_size=64, lr=3e-4):
    """Train the ViT on CIFAR-10."""
    print(f"MindSpore version: {ms.__version__}")
    print("Preparing the CIFAR-10 dataset...")
    # Datasets
    train_dataset = create_cifar10_dataset('./cifar-10-batches-bin', batch_size, training=True)
    test_dataset = create_cifar10_dataset('./cifar-10-batches-bin', batch_size, training=False)
    # Model
    model = create_tiny_vit()
    total_params = sum(param.size for param in model.get_parameters())
    print(f"Number of parameters: {total_params:,}")
    # Loss function
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # Cosine learning-rate schedule (one value per step), then the optimizer
    steps_per_epoch = train_dataset.get_dataset_size()
    lr_schedule = nn.cosine_decay_lr(
        min_lr=1e-6,
        max_lr=lr,
        total_step=epochs * steps_per_epoch,
        step_per_epoch=steps_per_epoch,
        decay_epoch=epochs,
    )
    optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_schedule, weight_decay=0.05)
    # Bookkeeping
    train_losses, train_accs = [], []
    test_losses, test_accs = [], []
    best_acc = 0
    print(f"Starting training for {epochs} epochs...")
    for epoch in range(epochs):
        start_time = time.time()
        # Train
        train_loss, train_acc = train_epoch(model, train_dataset, optimizer, loss_fn)
        # Evaluate
        test_loss, test_acc = test_epoch(model, test_dataset, loss_fn)
        # Record
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        # Save the best model
        if test_acc > best_acc:
            best_acc = test_acc
            ms.save_checkpoint(model, 'best_vit_cifar10_ms.ckpt')
        # Report
        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1}/{epochs} ({epoch_time:.1f}s)')
        print(f'  Train: Loss {train_loss:.4f}, Acc {train_acc:.2f}%')
        print(f'  Test:  Loss {test_loss:.4f}, Acc {test_acc:.2f}% (Best: {best_acc:.2f}%)')
        print('-' * 60)
    return {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'test_losses': test_losses,
        'test_accs': test_accs,
        'best_acc': best_acc,
    }

# 11. Plot the results
def plot_results(history):
    """Plot the training curves."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    epochs = range(1, len(history['train_losses']) + 1)
    # Loss curves
    ax1.plot(epochs, history['train_losses'], 'b-', label='Train Loss')
    ax1.plot(epochs, history['test_losses'], 'r-', label='Test Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Test Loss')
    ax1.legend()
    ax1.grid(True)
    # Accuracy curves
    ax2.plot(epochs, history['train_accs'], 'b-', label='Train Acc')
    ax2.plot(epochs, history['test_accs'], 'r-', label='Test Acc')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Training and Test Accuracy')
    ax2.legend()
    ax2.grid(True)
    plt.tight_layout()
    plt.show()

# 12. Entry point
if __name__ == "__main__":
    print("=== MindSpore ViT on CIFAR-10 ===\n")
    # Train
    try:
        history = train_vit_on_cifar10(epochs=15, batch_size=64, lr=3e-4)
        print(f"\nTraining finished! Best test accuracy: {history['best_acc']:.2f}%")
        # Plot the results
        plot_results(history)
    except Exception as e:
        print(f"An error occurred during training: {e}")
        print("Make sure MindSpore is installed and the environment is configured correctly")
        print("Install with: pip install mindspore")
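For completeness, a minimal sketch (my addition, not in the original) of reloading the saved MindSpore checkpoint for inference:

import mindspore as ms
import mindspore_build

# Rebuild the same tiny architecture and load the checkpoint into it.
model = mindspore_build.ViT(img_size=32, patch_size=4, in_channels=3, num_classes=10,
                            emb_dim=192, depth=6, num_heads=3, mlp_ratio=4.0, dropout=0.1)
ms.load_checkpoint('best_vit_cifar10_ms.ckpt', net=model)
model.set_train(False)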