
Deep Learning Series: The ShuffleNet Network Architecture

To implement ShuffleNet in PyTorch (using the classic ShuffleNet v1 as our example), the core task is implementing its two signature moves: grouped convolution and channel shuffle. We start from the most basic modules and build up step by step, so you can understand what each operation does.

I. The Core Structure of ShuffleNet v1

The structure of ShuffleNet v1 can be summarized as:

Input (224×224 RGB image) →
Initial conv layer → Initial pooling layer →
3 stages of ShuffleNet units (each stage: one stride-2 downsampling unit + several stride-1 feature-extraction units) →
Global average pooling → Fully connected layer (1000-class output)

The ShuffleNet unit is the core building block, and it comes in two variants: "stride = 1" (size-preserving) and "stride = 2" (downsampling).

II. Implementing ShuffleNet v1 in PyTorch, Step by Step

Step 1: Import the necessary libraries

As with the other CNNs we have implemented, first get the tools ready:

import torch  # core library
import torch.nn as nn  # neural network layers
import torch.optim as optim  # optimizers
from torch.utils.data import DataLoader  # data loader
from torchvision import datasets, transforms  # image data processing

Step 2: Implement the core operation, channel shuffle

Channel shuffle is ShuffleNet's signature innovation, designed to fix the "information isolation" problem of grouped convolution. Let's implement it first:

def channel_shuffle(x, groups):
    """Channel shuffle: re-distribute the channels produced by grouped convolution.

    x: input feature map of shape (batch_size, channels, height, width)
    groups: number of groups
    """
    batch_size, channels, height, width = x.size()
    # 1. The channel count must be divisible by the group count (a ShuffleNet design requirement)
    assert channels % groups == 0, "channel count must be a multiple of the group count"
    channels_per_group = channels // groups  # channels in each group
    # 2. The core shuffle steps:
    # split into (batch_size, groups, channels_per_group, height, width)
    x = x.view(batch_size, groups, channels_per_group, height, width)
    # swap the groups and channels_per_group dims -> (batch_size, channels_per_group, groups, height, width)
    x = x.transpose(1, 2).contiguous()
    # flatten the channel dims back -> (batch_size, channels, height, width)
    x = x.view(batch_size, -1, height, width)
    return x

Plain-language explanation
Suppose the input has 8 groups (groups=8) of 32 channels each (256 channels total). Channel shuffle turns "8 groups × 32 channels" into "32 groups × 8 channels", so each new group contains channels drawn from all 8 original groups, breaking the isolation between groups.
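To convince yourself the permutation behaves as described, here is a minimal check (my own addition, not part of the network code; it assumes torch and the channel_shuffle function above are in scope). With 2 groups of 3 channels, the channels [0,1,2 | 3,4,5] should interleave to [0,3,1,4,2,5]:

x = torch.arange(6).float().view(1, 6, 1, 1)  # channel c holds the value c
y = channel_shuffle(x, groups=2)
print(y.view(-1).tolist())  # expected: [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]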

Step 3: Implement the core ShuffleNet units

ShuffleNet has two kinds of units: the stride-1 unit (feature extraction, size unchanged) and the stride-2 unit (downsampling, size halved).

3.1 The stride-1 ShuffleNet unit (feature extraction)

class ShuffleNetUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups, stride=1):
        super(ShuffleNetUnitV1, self).__init__()
        self.stride = stride
        self.groups = groups
        # Output channels must be a multiple of the group count
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels (dimension reduction to cut compute)

        # Main branch: 1x1 group conv -> 3x3 depthwise conv -> 1x1 group conv
        self.main_branch = nn.Sequential(
            # 1x1 group conv (dimension reduction)
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # 3x3 depthwise conv (one filter per channel; stride controls downsampling)
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride,
                      padding=1, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            # 1x1 group conv (dimension expansion)
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Activation (applied at the end of the unit)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # stride=1: residual connection (input added to output);
        # note this requires in_channels == out_channels
        if self.stride == 1:
            # Channel shuffle first (key step: lets information flow across groups)
            x = channel_shuffle(x, self.groups)
            # Main branch
            out = self.main_branch(x)
            # Residual connection: input + output
            out += x
            return self.relu(out)
        # stride=2 (downsampling) is handled by the dedicated unit below;
        # this branch is kept only for completeness
        else:
            x = channel_shuffle(x, self.groups)
            out = self.main_branch(x)
            return out
3.2 The stride-2 ShuffleNet unit (downsampling)

The stride-2 unit needs to downsample (halve the spatial size), so it uses a two-branch design:

class ShuffleNetDownUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super(ShuffleNetDownUnitV1, self).__init__()
        self.groups = groups
        # Output channels must be a multiple of the group count
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels

        # Main branch: 1x1 group conv -> 3x3 depthwise conv (stride 2) -> 1x1 group conv
        self.main_branch = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # stride=2 performs the downsampling (halves the spatial size)
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=2,
                      padding=1, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            # output channels = out_channels minus the shortcut branch's channels
            nn.Conv2d(mid_channels, out_channels - in_channels, kernel_size=1,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_channels - in_channels),
        )
        # Shortcut branch: average pooling (stride 2, downsampling)
        self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        # Activation
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Channel shuffle first
        x = channel_shuffle(x, self.groups)
        # Main branch
        main_out = self.main_branch(x)
        # Shortcut branch (downsampled)
        shortcut_out = self.shortcut_branch(x)
        # Concatenate the two branches (total channels = out_channels)
        out = torch.cat([main_out, shortcut_out], dim=1)
        return self.relu(out)

Key differences

  • Stride 1: residual connection (input + output); the channel count stays the same, which is why this unit requires in_channels == out_channels;
  • Stride 2: two branches concatenated (main branch + shortcut branch); the spatial size halves and the concatenation raises the channel count to out_channels (see the shape check below).
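A quick shape check makes the difference concrete (a minimal sketch, assuming the two unit classes above are defined; the channel counts and the 28×28 size are illustrative):

# stride-1 unit: channel count and spatial size are both preserved
unit = ShuffleNetUnitV1(in_channels=384, out_channels=384, groups=8)
print(unit(torch.randn(1, 384, 28, 28)).shape)  # torch.Size([1, 384, 28, 28])

# stride-2 unit: spatial size halves; the main branch (out - in channels)
# and the pooled shortcut (in channels) concatenate to reach out_channels
down = ShuffleNetDownUnitV1(in_channels=192, out_channels=384, groups=8)
print(down(torch.randn(1, 192, 28, 28)).shape)  # torch.Size([1, 384, 14, 14])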

Step 4: Assemble the complete ShuffleNet v1 network

Using the units defined above, build the full network following the ShuffleNet v1 structure. Note that every stage, including the first, begins with a stride-2 downsampling unit; that unit is also what raises the channel count, so the stride-1 units that follow can use residual connections:

class ShuffleNetV1(nn.Module):
    def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):
        super(ShuffleNetV1, self).__init__()
        self.groups = groups
        # Base channel counts (scaled by scale_factor to adjust model size)
        base_channels = [24, 192, 384, 768]
        base_channels = [int(c * scale_factor) for c in base_channels]

        # 1. Initial convolution layer
        self.features = nn.Sequential(
            nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_channels[0]),
            nn.ReLU(inplace=True),
            # 2. Initial pooling layer
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # 3. Three stages of ShuffleNet units
        # Stage 1: 1 stride-2 downsampling unit + 3 stride-1 units
        self.stage1 = self._make_stage(
            in_channels=base_channels[0],
            out_channels=base_channels[1],
            groups=groups,
            num_units=3,
            is_downsample=True
        )
        # Stage 2: 1 stride-2 downsampling unit + 7 stride-1 units
        self.stage2 = self._make_stage(
            in_channels=base_channels[1],
            out_channels=base_channels[2],
            groups=groups,
            num_units=7,
            is_downsample=True
        )
        # Stage 3: 1 stride-2 downsampling unit + 3 stride-1 units
        self.stage3 = self._make_stage(
            in_channels=base_channels[2],
            out_channels=base_channels[3],
            groups=groups,
            num_units=3,
            is_downsample=True
        )

        # 4. Global average pooling
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # 5. Fully connected layer (class scores)
        self.classifier = nn.Linear(base_channels[3], num_classes)

    def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):
        """Build one stage (a sequence of ShuffleNet units)."""
        stage = []
        # If downsampling is needed, start with a stride-2 unit
        # (this also raises the channel count from in_channels to out_channels)
        if is_downsample:
            stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))
            in_channels = out_channels  # update the input channel count
        # Append num_units stride-1 units
        for _ in range(num_units):
            stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.features(x)       # initial conv and pooling
        x = self.stage1(x)         # stage 1
        x = self.stage2(x)         # stage 2
        x = self.stage3(x)         # stage 3
        x = self.global_pool(x)    # global pooling
        x = x.view(x.size(0), -1)  # flatten to a vector
        x = self.classifier(x)     # fully connected output
        return x
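Before moving on, a quick forward-pass sanity check is worthwhile (a minimal sketch using the class defined above):

model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0)
x = torch.randn(2, 3, 224, 224)  # a dummy batch of two 224×224 RGB images
print(model(x).shape)            # expected: torch.Size([2, 10])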

Structure notes

  • Group count (groups): controls the number of groups in the grouped convolutions (commonly 1, 2, 3, 4, or 8); more groups means less computation, at some cost in accuracy;
  • Scale factor (scale_factor): scales the channel counts (e.g., 0.5, 1.0, 1.5) to adjust model size (0.5 is the lightweight version, 1.5 the higher-accuracy version); note that the scaled channel counts must stay divisible by the group count, or the grouped convolutions cannot be constructed (see the sketch after this list);
  • Stage design: each stage first downsamples via a stride-2 unit (halving the spatial size), then extracts features with several stride-1 units, progressively increasing the level of abstraction.
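To see the groups knob in action, you can compare parameter counts across configurations (a quick sketch using the ShuffleNetV1 class above; the exact numbers depend on the channel table):

for g in (1, 2, 4, 8):
    m = ShuffleNetV1(num_classes=1000, groups=g, scale_factor=1.0)
    n_params = sum(p.numel() for p in m.parameters())
    print(f'groups={g}: {n_params / 1e6:.2f}M parameters')  # larger g -> fewer params in the 1×1 convs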

Step 5: Prepare the data (demonstrated with CIFAR-10)

ShuffleNet targets mobile devices; we demonstrate with CIFAR-10 (10 classes), resizing inputs to 224×224:

# Preprocessing: resize + crop + flip + normalize
transform = transforms.Compose([
    transforms.Resize(256),              # resize to 256×256
    transforms.RandomCrop(224),          # random crop to 224×224
    transforms.RandomHorizontalFlip(),   # random flip (data augmentation)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

# Load the CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Batch loading (ShuffleNet is light, so batch_size can be fairly large)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
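One caveat: the code above reuses the augmented training transform for the test set, so evaluation sees random crops and flips. A deterministic eval pipeline is usually preferred (an optional refinement, not in the original recipe):

eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),  # deterministic center crop instead of a random one
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=eval_transform
)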

Step 6: Initialize the model, loss function, and optimizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# groups=8, scale_factor=1.0, 10 output classes (CIFAR-10)
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)

criterion = nn.CrossEntropyLoss()  # cross-entropy loss
# Optimizer: SGD with momentum is recommended
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
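With SGD it is also common to decay the learning rate over training; a cosine schedule is one reasonable choice (an optional addition, not part of the original recipe):

# Anneal the learning rate from 0.01 toward 0 over the 20 training epochs
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
# call scheduler.step() once per epoch, after that epoch's training loop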

Step 7: Training and testing functions

The training logic is similar to our previous models:

def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()             # clear gradients
        output = model(data)              # forward pass
        loss = criterion(output, target)  # compute loss
        loss.backward()                   # backpropagation
        optimizer.step()                  # update parameters
        # Print progress
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')
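If you want to keep the best weights across epochs, a small variant of test that returns the accuracy makes checkpointing easy (a sketch; the filename best_model.pth is arbitrary):

def evaluate(model, test_loader):
    # same logic as test(), but returns the accuracy instead of only printing it
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total

best_acc = 0.0
# inside the epoch loop:
#     acc = evaluate(model, test_loader)
#     if acc > best_acc:
#         best_acc = acc
#         torch.save(model.state_dict(), 'best_model.pth')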

Step 8: Train and test

ShuffleNet is very lightweight and trains quickly; we train for 20 epochs here:

for epoch in range(1, 21):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)

On CIFAR-10, a well-trained ShuffleNet v1 (groups=8, scale_factor=1.0) can reach around 85% accuracy, with only about 2.3 million parameters (roughly half the size of MobileNet v1, and about 1.7% of VGG-16).

III. Complete Code

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 1. Channel shuffle operation
def channel_shuffle(x, groups):
    batch_size, channels, height, width = x.size()
    assert channels % groups == 0, "channel count must be a multiple of the group count"
    channels_per_group = channels // groups
    # Core shuffle steps: split into groups -> transpose -> flatten
    x = x.view(batch_size, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    x = x.view(batch_size, -1, height, width)
    return x

# 2. Stride-1 ShuffleNet unit (feature extraction)
class ShuffleNetUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups, stride=1):
        super(ShuffleNetUnitV1, self).__init__()
        self.stride = stride
        self.groups = groups
        assert out_channels % groups == 0
        mid_channels = out_channels // 4  # bottleneck channels (dimension reduction)
        self.main_branch = nn.Sequential(
            # 1x1 group conv (reduce)
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # 3x3 depthwise conv
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride,
                      padding=1, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            # 1x1 group conv (expand)
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        if self.stride == 1:
            x = channel_shuffle(x, self.groups)  # channel shuffle
            out = self.main_branch(x)
            out += x  # residual connection (requires in_channels == out_channels)
            return self.relu(out)
        else:
            x = channel_shuffle(x, self.groups)
            out = self.main_branch(x)
            return out

# 3. Stride-2 ShuffleNet unit (downsampling)
class ShuffleNetDownUnitV1(nn.Module):
    def __init__(self, in_channels, out_channels, groups):
        super(ShuffleNetDownUnitV1, self).__init__()
        self.groups = groups
        assert out_channels % groups == 0
        mid_channels = out_channels // 4
        self.main_branch = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # stride=2: downsample
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=2,
                      padding=1, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(mid_channels, out_channels - in_channels, kernel_size=1,
                      groups=groups, bias=False),
            nn.BatchNorm2d(out_channels - in_channels),
        )
        # Shortcut branch: average-pool downsampling
        self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = channel_shuffle(x, self.groups)
        main_out = self.main_branch(x)
        shortcut_out = self.shortcut_branch(x)
        out = torch.cat([main_out, shortcut_out], dim=1)  # concatenate branches
        return self.relu(out)

# 4. The complete ShuffleNet v1 network
class ShuffleNetV1(nn.Module):
    def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):
        super(ShuffleNetV1, self).__init__()
        self.groups = groups
        # Base channel counts (adjusted by the scale factor)
        base_channels = [24, 192, 384, 768]
        base_channels = [int(c * scale_factor) for c in base_channels]
        # Initial convolution and pooling
        self.features = nn.Sequential(
            nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(base_channels[0]),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Three stages; each begins with a stride-2 downsampling unit
        self.stage1 = self._make_stage(
            in_channels=base_channels[0], out_channels=base_channels[1],
            groups=groups, num_units=3, is_downsample=True
        )
        self.stage2 = self._make_stage(
            in_channels=base_channels[1], out_channels=base_channels[2],
            groups=groups, num_units=7, is_downsample=True
        )
        self.stage3 = self._make_stage(
            in_channels=base_channels[2], out_channels=base_channels[3],
            groups=groups, num_units=3, is_downsample=True
        )
        # Global pooling and classifier
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(base_channels[3], num_classes)

    def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):
        stage = []
        if is_downsample:
            stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))
            in_channels = out_channels
        for _ in range(num_units):
            stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.features(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# 5. CIFAR-10 data
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

# 6. Model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# 7. Training function
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

# 8. Testing function
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')

# 9. Train and test
for epoch in range(1, 21):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader)

IV. Key Takeaways

  1. Core operation: channel shuffle breaks the information isolation of grouped convolution through three steps (group → transpose → flatten); it is the soul of ShuffleNet;
  2. Unit design: the stride-1 unit uses a residual connection (input + output), while the stride-2 unit concatenates two branches (main + shortcut), downsampling while keeping features flowing;
  3. Lightweight advantage: ShuffleNet v1 (groups=8) has only about 2.3M parameters and roughly 140 MFLOPs, making it one of the least resource-hungry models at comparable accuracy;
  4. Flexible configuration: tuning groups (group count) and scale_factor (channel scaling) trades accuracy against efficiency to suit different mobile scenarios.

With this code in hand, you can implement this "channel shuffle master" yourself and get a feel for the potential of lightweight CNNs on resource-constrained devices.
