深度学习篇---ShuffleNet网络结构
在 PyTorch 中实现 ShuffleNet(以经典的 ShuffleNet v1 为例),核心是实现它的两个 "招牌动作"——分组卷积和通道混洗。我们从最基础的模块开始,一步步搭建,确保你能理解每个操作的作用。
一、先明确 ShuffleNet v1 的核心结构
ShuffleNet v1 的结构可以概括为:
输入(224×224彩色图) →
初始卷积层 → 初始池化层 →
3个阶段的ShuffleNet单元(每个阶段包含步长=2的下采样单元+多个步长=1的特征提取单元) →
全局平均池化 → 全连接层(输出1000类)
其中,ShuffleNet 单元是核心,分为 "步长 = 1"(保持尺寸)和 "步长 = 2"(下采样)两种。
二、PyTorch 实现 ShuffleNet v1 的步骤
步骤 1:导入必要的库
和之前实现其他 CNN 一样,先准备好工具:
import torch # 核心库
import torch.nn as nn # 神经网络层
import torch.optim as optim # 优化器
from torch.utils.data import DataLoader # 数据加载器
from torchvision import datasets, transforms # 图像数据处理
步骤 2:实现核心操作 —— 通道混洗(Channel Shuffle)
通道混洗是 ShuffleNet 的标志性创新,用于解决分组卷积的 "信息隔绝" 问题。我们先实现这个操作:
def channel_shuffle(x, groups):"""通道混洗操作:将分组卷积后的通道重新打乱分配x: 输入特征图,形状为(batch_size, channels, height, width)groups: 分组数量"""batch_size, channels, height, width = x.size()# 1. 确保通道数能被分组数整除(ShuffleNet的设计要求)assert channels % groups == 0, "通道数必须是分组数的整数倍"channels_per_group = channels // groups # 每组的通道数# 2. 通道混洗的核心步骤:# 拆分成 (batch_size, groups, channels_per_group, height, width)x = x.view(batch_size, groups, channels_per_group, height, width)# 交换groups和channels_per_group维度 → (batch_size, channels_per_group, groups, height, width)x = x.transpose(1, 2).contiguous()# 重新展平通道维度 → (batch_size, channels, height, width)x = x.view(batch_size, -1, height, width)return x
通俗解释:
假设输入是 8 组(groups=8),每组 32 通道(共 256 通道),通道混洗会把 "8 组 ×32 通道" 变成 "32 组 ×8 通道",让新的每组通道都包含原来 8 组的信息,打破分组间的隔绝。
步骤 3:实现 ShuffleNet 的核心单元
ShuffleNet 有两种单元:步长 = 1 的单元(特征提取,尺寸不变)和步长 = 2 的单元(下采样,尺寸减半)。
3.1 步长 = 1 的 ShuffleNet 单元(特征提取)
class ShuffleNetUnitV1(nn.Module):def __init__(self, in_channels, out_channels, groups, stride=1):super(ShuffleNetUnitV1, self).__init__()self.stride = strideself.groups = groups# 确保输出通道数是分组数的整数倍assert out_channels % groups == 0mid_channels = out_channels // 4 # 中间通道数(降维用,减少计算量)# 主分支:1×1分组卷积 → 3×3深度卷积 → 1×1分组卷积self.main_branch = nn.Sequential(# 1×1分组卷积(降维)nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),# 3×3深度卷积(每个通道单独卷积,步长控制是否下采样)nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, groups=mid_channels, bias=False),nn.BatchNorm2d(mid_channels),# 1×1分组卷积(升维)nn.Conv2d(mid_channels, out_channels, kernel_size=1, groups=groups, bias=False),nn.BatchNorm2d(out_channels),)# 激活函数(放在单元最后)self.relu = nn.ReLU(inplace=True)def forward(self, x):# 步长=1时,使用残差连接(输入直接加输出)if self.stride == 1:# 先做通道混洗(关键!让分组卷积的信息流通)x = channel_shuffle(x, self.groups)# 主分支计算out = self.main_branch(x)# 残差连接:输入+输出out += xreturn self.relu(out)# 步长=2时的处理(下采样,后面单独实现)else:# 先做通道混洗x = channel_shuffle(x, self.groups)# 主分支计算out = self.main_branch(x)return out
3.2 步长 = 2 的 ShuffleNet 单元(下采样)
步长 = 2 的单元需要下采样(尺寸减半),因此用 "双分支" 设计:
class ShuffleNetDownUnitV1(nn.Module):def __init__(self, in_channels, out_channels, groups):super(ShuffleNetDownUnitV1, self).__init__()self.groups = groups# 确保输出通道数是分组数的整数倍assert out_channels % groups == 0mid_channels = out_channels // 4 # 中间通道数# 主分支:1×1分组卷积 → 3×3深度卷积(步长2) → 1×1分组卷积self.main_branch = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),# 步长=2,实现下采样(尺寸减半)nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False),nn.BatchNorm2d(mid_channels),nn.Conv2d(mid_channels, out_channels - in_channels, kernel_size=1, # 输出通道数=总通道数-捷径分支通道数groups=groups, bias=False),nn.BatchNorm2d(out_channels - in_channels),)# 捷径分支:平均池化(步长2,下采样)self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)# 激活函数self.relu = nn.ReLU(inplace=True)def forward(self, x):# 先做通道混洗x = channel_shuffle(x, self.groups)# 主分支计算main_out = self.main_branch(x)# 捷径分支计算(下采样)shortcut_out = self.shortcut_branch(x)# 拼接两个分支(主分支+捷径分支,总通道数=out_channels)out = torch.cat([main_out, shortcut_out], dim=1)return self.relu(out)
关键差异:
- 步长 = 1:用残差连接(输入 + 输出),保持通道数不变;
- 步长 = 2:用双分支拼接(主分支 + 捷径分支),通道数翻倍,尺寸减半。
步骤 4:搭建 ShuffleNet v1 完整网络
用上面定义的单元,按 ShuffleNet v1 的结构搭建完整网络:
class ShuffleNetV1(nn.Module):def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):super(ShuffleNetV1, self).__init__()self.groups = groups# 基础通道数(根据scale_factor调整模型大小)base_channels = [24, 192, 384, 768]base_channels = [int(c * scale_factor) for c in base_channels]# 1. 初始卷积层self.features = nn.Sequential(nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(base_channels[0]),nn.ReLU(inplace=True),# 2. 初始池化层nn.MaxPool2d(kernel_size=3, stride=2, padding=1))# 3. 三个阶段的ShuffleNet单元# 阶段1:3个步长=1的单元self.stage1 = self._make_stage(in_channels=base_channels[0],out_channels=base_channels[1],groups=groups,num_units=3)# 阶段2:1个步长=2的下采样单元 + 7个步长=1的单元self.stage2 = self._make_stage(in_channels=base_channels[1],out_channels=base_channels[2],groups=groups,num_units=7,is_downsample=True)# 阶段3:1个步长=2的下采样单元 + 3个步长=1的单元self.stage3 = self._make_stage(in_channels=base_channels[2],out_channels=base_channels[3],groups=groups,num_units=3,is_downsample=True)# 4. 全局平均池化self.global_pool = nn.AdaptiveAvgPool2d((1, 1))# 5. 全连接层(输出类别)self.classifier = nn.Linear(base_channels[3], num_classes)def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):"""构建一个阶段的网络(包含多个ShuffleNet单元)"""stage = []# 如果需要下采样,先加一个步长=2的单元if is_downsample:stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))in_channels = out_channels # 更新输入通道数# 加入num_units个步长=1的单元for _ in range(num_units):stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))return nn.Sequential(*stage)def forward(self, x):x = self.features(x) # 初始卷积和池化x = self.stage1(x) # 阶段1x = self.stage2(x) # 阶段2x = self.stage3(x) # 阶段3x = self.global_pool(x) # 全局池化x = x.view(x.size(0), -1) # 拉平成向量x = self.classifier(x) # 全连接层输出return x
结构解释:
- 分组数(groups):控制分组卷积的组数(可选 1、2、3、4、8),组数越大,计算量越小(但需平衡精度);
- 缩放因子(scale_factor):控制通道数(如 0.5、1.0、1.5),用于调整模型大小(0.5 是轻量版,1.5 是高精度版);
- 阶段设计:每个阶段先通过步长 = 2 的单元下采样(尺寸减半),再用多个步长 = 1 的单元提取特征,逐步提升特征抽象程度。
步骤 5:准备数据(用 CIFAR-10 演示)
ShuffleNet 适合移动端,我们用 CIFAR-10(10 类)演示,输入尺寸调整为 224×224:
# 数据预处理:缩放+裁剪+翻转+标准化
transform = transforms.Compose([transforms.Resize(256), # 缩放为256×256transforms.RandomCrop(224), # 随机裁剪成224×224transforms.RandomHorizontalFlip(), # 随机翻转(数据增强)transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化
])# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform
)# 批量加载数据(ShuffleNet轻量,batch_size可以设大些)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
步骤 6:初始化模型、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 分组数=8,缩放因子=1.0,输出10类(CIFAR-10)
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)criterion = nn.CrossEntropyLoss() # 交叉熵损失
# 优化器:推荐用SGD+动量
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
步骤 7:训练和测试函数
训练逻辑和之前的模型类似:
def train(model, train_loader, criterion, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad() # 清空梯度output = model(data) # 模型预测loss = criterion(output, target) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数# 打印进度if batch_idx % 100 == 0:print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')def test(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()print(f'Test Accuracy: {100 * correct / total:.2f}%')
步骤 8:开始训练和测试
ShuffleNet 非常轻量,训练速度很快,这里训练 20 轮:
for epoch in range(1, 21):train(model, train_loader, criterion, optimizer, epoch)test(model, test_loader)
在 CIFAR-10 上,ShuffleNet v1(groups=8,scale_factor=1.0)训练充分后准确率能达到 85% 左右,且参数量仅约 2.3 百万(是 MobileNet v1 的 50%,VGG-16 的 1.7%)。
三、完整代码总结
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 1. 实现通道混洗操作
def channel_shuffle(x, groups):batch_size, channels, height, width = x.size()assert channels % groups == 0, "通道数必须是分组数的整数倍"channels_per_group = channels // groups# 核心混洗步骤:分组→转置→展平x = x.view(batch_size, groups, channels_per_group, height, width)x = x.transpose(1, 2).contiguous()x = x.view(batch_size, -1, height, width)return x# 2. 步长=1的ShuffleNet单元(特征提取)
class ShuffleNetUnitV1(nn.Module):def __init__(self, in_channels, out_channels, groups, stride=1):super(ShuffleNetUnitV1, self).__init__()self.stride = strideself.groups = groupsassert out_channels % groups == 0mid_channels = out_channels // 4 # 中间通道数(降维)self.main_branch = nn.Sequential(# 1×1分组卷积(降维)nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),# 3×3深度卷积nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, groups=mid_channels, bias=False),nn.BatchNorm2d(mid_channels),# 1×1分组卷积(升维)nn.Conv2d(mid_channels, out_channels, kernel_size=1, groups=groups, bias=False),nn.BatchNorm2d(out_channels),)self.relu = nn.ReLU(inplace=True)def forward(self, x):if self.stride == 1:x = channel_shuffle(x, self.groups) # 通道混洗out = self.main_branch(x)out += x # 残差连接return self.relu(out)else:x = channel_shuffle(x, self.groups)out = self.main_branch(x)return out# 3. 步长=2的ShuffleNet单元(下采样)
class ShuffleNetDownUnitV1(nn.Module):def __init__(self, in_channels, out_channels, groups):super(ShuffleNetDownUnitV1, self).__init__()self.groups = groupsassert out_channels % groups == 0mid_channels = out_channels // 4self.main_branch = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=1, groups=groups, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),# 步长=2,下采样nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False),nn.BatchNorm2d(mid_channels),nn.Conv2d(mid_channels, out_channels - in_channels, kernel_size=1, groups=groups, bias=False),nn.BatchNorm2d(out_channels - in_channels),)# 捷径分支:平均池化下采样self.shortcut_branch = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = channel_shuffle(x, self.groups)main_out = self.main_branch(x)shortcut_out = self.shortcut_branch(x)out = torch.cat([main_out, shortcut_out], dim=1) # 拼接分支return self.relu(out)# 4. 搭建ShuffleNet v1完整网络
class ShuffleNetV1(nn.Module):def __init__(self, num_classes=1000, groups=8, scale_factor=1.0):super(ShuffleNetV1, self).__init__()self.groups = groups# 基础通道数(按缩放因子调整)base_channels = [24, 192, 384, 768]base_channels = [int(c * scale_factor) for c in base_channels]# 初始卷积和池化self.features = nn.Sequential(nn.Conv2d(3, base_channels[0], kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(base_channels[0]),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))# 三个阶段的网络self.stage1 = self._make_stage(in_channels=base_channels[0],out_channels=base_channels[1],groups=groups,num_units=3)self.stage2 = self._make_stage(in_channels=base_channels[1],out_channels=base_channels[2],groups=groups,num_units=7,is_downsample=True)self.stage3 = self._make_stage(in_channels=base_channels[2],out_channels=base_channels[3],groups=groups,num_units=3,is_downsample=True)# 全局池化和分类器self.global_pool = nn.AdaptiveAvgPool2d((1, 1))self.classifier = nn.Linear(base_channels[3], num_classes)def _make_stage(self, in_channels, out_channels, groups, num_units, is_downsample=False):stage = []if is_downsample:stage.append(ShuffleNetDownUnitV1(in_channels, out_channels, groups))in_channels = out_channelsfor _ in range(num_units):stage.append(ShuffleNetUnitV1(in_channels, out_channels, groups))return nn.Sequential(*stage)def forward(self, x):x = self.features(x)x = self.stage1(x)x = self.stage2(x)x = self.stage3(x)x = self.global_pool(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x# 5. 准备CIFAR-10数据
transform = transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform
)train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)# 6. 初始化模型、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ShuffleNetV1(num_classes=10, groups=8, scale_factor=1.0).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)# 7. 训练函数
def train(model, train_loader, criterion, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')# 8. 测试函数
def test(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()print(f'Test Accuracy: {100 * correct / total:.2f}%')# 9. 开始训练和测试
for epoch in range(1, 21):train(model, train_loader, criterion, optimizer, epoch)test(model, test_loader)
四、关键知识点回顾
- 核心操作:通道混洗通过 "分组→转置→展平" 三步,打破分组卷积的信息隔绝,这是 ShuffleNet 的灵魂;
- 单元设计:步长 = 1 的单元用残差连接(输入 + 输出),步长 = 2 的单元用双分支拼接(主分支 + 捷径分支),实现下采样的同时保证特征流通;
- 轻量化优势:ShuffleNet v1(groups=8)参数量仅 2.3M,计算量 140MFLOPs,是同等精度模型中资源占用最少的之一;
- 灵活配置:通过调整
groups
(分组数)和scale_factor
(缩放因子),可在 "精度" 和 "效率" 之间灵活权衡,满足不同移动端场景需求。
通过这段代码,你能亲手实现这个 "通道混洗大师" 模型,感受轻量化 CNN 在资源受限设备上的强大潜力!