当前位置: 首页 > news >正文

深度学习篇---MobileNet网络结构

在 PyTorch 中实现 MobileNet(以经典的 MobileNet v1 为例)并不复杂,核心是实现它的 "杀手锏"—— 深度可分离卷积。我们一步步来,从最基础的模块开始搭建,保证你能看懂每一行代码的作用。

一、先明确 MobileNet v1 的核心结构

MobileNet v1 的结构可以概括为:

输入(224×224彩色图) → 
标准卷积层 → 
13个深度可分离卷积块 → 
全局平均池化 → 
全连接层(输出1000类)

其中,深度可分离卷积块是核心,每个块由 "深度卷积 + 逐点卷积 + ReLU + 池化 (可选)" 组成。

二、PyTorch 实现 MobileNet v1 的步骤

步骤 1:导入必要的库

和之前实现其他 CNN 一样,先准备好工具:

import torch  # 核心库
import torch.nn as nn  # 神经网络层
import torch.optim as optim  # 优化器
from torch.utils.data import DataLoader  # 数据加载器
from torchvision import datasets, transforms  # 图像数据处理

步骤 2:实现核心模块 —— 深度可分离卷积块

MobileNet 的精髓在于 "深度可分离卷积",我们先定义这个基础模块:

class DepthwiseSeparableConv(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(DepthwiseSeparableConv, self).__init__()# 1. 深度卷积:每个输入通道用1个3×3卷积核单独处理self.depthwise = nn.Conv2d(in_channels=in_channels,out_channels=in_channels,  # 输出通道数=输入通道数(每个通道单独处理)kernel_size=3,stride=stride,padding=1,groups=in_channels  # 分组卷积:groups=in_channels表示每个通道单独卷积)# 2. 逐点卷积:用1×1卷积核融合通道(把in_channels→out_channels)self.pointwise = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,  # 1×1卷积核stride=1,padding=0)# 激活函数和批归一化(提升训练稳定性)self.bn1 = nn.BatchNorm2d(in_channels)self.bn2 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):# 深度卷积→批归一化→ReLUx = self.depthwise(x)x = self.bn1(x)x = self.relu(x)# 逐点卷积→批归一化→ReLUx = self.pointwise(x)x = self.bn2(x)x = self.relu(x)return x

关键解释

  • 深度卷积:通过groups=in_channels实现 "每个输入通道用自己的卷积核",比如 3 通道输入就用 3 个卷积核,每个处理 1 个通道;
  • 逐点卷积:1×1 卷积核的作用是 "融合通道",把深度卷积输出的in_channels个通道,压缩或扩展到out_channels个;
  • 批归一化(BN):MobileNet 中大量使用 BN,让训练更稳定,收敛更快。

步骤 3:搭建 MobileNet v1 完整网络

用上面定义的深度可分离卷积块,按 MobileNet v1 的结构搭建完整网络:

结构解释

  • 宽度乘法器:通过alpha参数控制所有通道数,比如alpha=0.5时通道数减半,模型更轻量;
  • 步长设计:每隔几个块用stride=2的卷积,让特征图尺寸逐步减半(224→112→56→28→14→7);
  • 全局平均池化:替代了传统 CNN 的全连接层前的拉平操作,进一步减少参数。

步骤 4:准备数据(用 CIFAR-10 演示)

MobileNet 适合移动设备,我们用 CIFAR-10(10 类)演示,输入尺寸调整为 224×224:

python

运行

# 数据预处理:缩放+裁剪+翻转+标准化
transform = transforms.Compose([transforms.Resize(256),  # 缩放为256×256transforms.RandomCrop(224),  # 随机裁剪成224×224transforms.RandomHorizontalFlip(),  # 随机翻转(数据增强)transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化
])# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform
)# 批量加载数据(MobileNet轻量,batch_size可以设大些)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

步骤 5:初始化模型、损失函数和优化器

python

运行

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 用width_multiplier=1.0的标准模型,输出10类(CIFAR-10)
model = MobileNetV1(num_classes=10, width_multiplier=1.0).to(device)criterion = nn.CrossEntropyLoss()  # 交叉熵损失
# 优化器:MobileNet推荐用RMSprop,学习率0.001
optimizer = optim.RMSprop(model.parameters(), lr=0.001, momentum=0.9)

步骤 6:训练和测试函数

和之前的模型类似,训练逻辑如下:

def train(model, train_loader, criterion, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()  # 清空梯度output = model(data)   # 模型预测loss = criterion(output, target)  # 计算损失loss.backward()        # 反向传播optimizer.step()       # 更新参数# 打印进度if batch_idx % 100 == 0:print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')def test(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()print(f'Test Accuracy: {100 * correct / total:.2f}%')

步骤 7:开始训练和测试

MobileNet 轻量,训练速度比 VGG 快很多,这里训练 20 轮:

for epoch in range(1, 21):train(model, train_loader, criterion, optimizer, epoch)test(model, test_loader)

在 CIFAR-10 上,MobileNet v1(width_multiplier=1.0)训练充分后准确率能达到 85% 左右,且训练速度比 VGG 快 5-10 倍。

三、完整代码总结

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 1. 定义深度可分离卷积块
class DepthwiseSeparableConv(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(DepthwiseSeparableConv, self).__init__()# 深度卷积(每个通道单独卷积)self.depthwise = nn.Conv2d(in_channels=in_channels,out_channels=in_channels,kernel_size=3,stride=stride,padding=1,groups=in_channels  # 分组卷积实现深度卷积)# 逐点卷积(1×1卷积融合通道)self.pointwise = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,stride=1,padding=0)# 批归一化和激活函数self.bn1 = nn.BatchNorm2d(in_channels)self.bn2 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.depthwise(x)x = self.bn1(x)x = self.relu(x)x = self.pointwise(x)x = self.bn2(x)x = self.relu(x)return x# 2. 定义MobileNet v1完整网络
class MobileNetV1(nn.Module):def __init__(self, num_classes=1000, width_multiplier=1.0):super(MobileNetV1, self).__init__()alpha = width_multiplier  # 宽度乘法器# 通道数(乘以宽度乘法器)channels = [int(32 * alpha),int(64 * alpha),int(128 * alpha),int(256 * alpha),int(512 * alpha),int(1024 * alpha)]# 特征提取部分self.features = nn.Sequential(# 初始标准卷积nn.Conv2d(3, channels[0], kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(channels[0]),nn.ReLU(inplace=True),# 13个深度可分离卷积块DepthwiseSeparableConv(channels[0], channels[1], stride=1),DepthwiseSeparableConv(channels[1], channels[2], stride=2),DepthwiseSeparableConv(channels[2], channels[2], stride=1),DepthwiseSeparableConv(channels[2], channels[3], stride=2),DepthwiseSeparableConv(channels[3], channels[3], stride=1),DepthwiseSeparableConv(channels[3], channels[4], stride=2),DepthwiseSeparableConv(channels[4], channels[4], stride=1),DepthwiseSeparableConv(channels[4], channels[4], stride=1),DepthwiseSeparableConv(channels[4], channels[4], stride=1),DepthwiseSeparableConv(channels[4], channels[4], stride=1),DepthwiseSeparableConv(channels[4], channels[5], stride=2),DepthwiseSeparableConv(channels[5], channels[5], stride=1),# 全局平均池化nn.AdaptiveAvgPool2d((1, 1)))# 分类部分self.classifier = nn.Linear(channels[5], num_classes)def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)  # 拉平特征x = self.classifier(x)return x# 3. 准备CIFAR-10数据
transform = transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform
)train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)# 4. 初始化模型、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MobileNetV1(num_classes=10, width_multiplier=1.0).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001, momentum=0.9)# 5. 训练函数
def train(model, train_loader, criterion, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')# 6. 测试函数
def test(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()print(f'Test Accuracy: {100 * correct / total:.2f}%')# 7. 开始训练和测试
for epoch in range(1, 21):train(model, train_loader, criterion, optimizer, epoch)test(model, test_loader)

四、关键知识点回顾

  1. 核心模块:深度可分离卷积 = 深度卷积(分组卷积实现)+ 逐点卷积(1×1 卷积),这是 MobileNet 轻量化的关键;
  2. 宽度乘法器:通过width_multiplier参数灵活控制模型大小,比如alpha=0.5时模型参数量减少到原来的 25%;
  3. 优势体现:MobileNet v1 参数量约 420 万(仅为 VGG-16 的 3%),计算量约 0.58 亿次乘法(仅为 VGG-16 的 4%),但精度下降很少;
  4. 适用场景:代码可直接用于移动端部署,只需将训练好的模型转换为 ONNX 或 TensorRT 格式,就能在手机、嵌入式设备上高效运行。

通过这段代码,你能亲手实现这个 "移动端 AI 利器",感受轻量化 CNN 的强大魅力!

http://www.xdnf.cn/news/1412101.html

相关文章:

  • 五分钟聊一聊AQS源码
  • globals() 小技巧
  • 仅有一张Fig的8分文章 胞外囊泡lncRNA+ CT 多模态融合模型,AUC 最高达 94.8%
  • 【LeetCode修行之路】算法的时间和空间复杂度分析
  • 大数据毕业设计选题推荐-基于大数据的大气和海洋动力学数据分析与可视化系统-Spark-Hadoop-Bigdata
  • ESP32C3 系列实战(1) --点亮小灯
  • Wi-Fi技术——物理层技术
  • 使用Cadence工具完成数模混合设计流程简介
  • LangChain核心抽象:Runnable接口深度解析
  • leetcode_48 旋转图像
  • FFMPEG学习任务
  • 第 14 篇:K-Means与聚类思维——当AI在没有“标准答案”的世界里寻宝
  • 【C2000】C2000的硬件设计指导与几点意见
  • 开源知识抽取框架 推荐
  • 京东获取商品评论指南,实时关注用户反馈
  • 官方 API 与网络爬虫的技术特性对比及选型分析
  • Unity学习----【数据持久化】二进制存储(三)--文件夹操作
  • OpenStack 01:介绍
  • 暄桐林曦老师关于静坐常见问题的QA
  • 基于GA遗传优化的双向LSTM融合多头注意力(BiLSTM-MATT)时间序列预测算法matlab仿真
  • windows系统中的docker,xinference直接运行在容器目录和持载在宿主机目录中的区别
  • isat将标签转化为labelme格式后,labelme打不开的解决方案
  • MyBatis 黑马 辅助配置,数据库连接池
  • 柔性数组与不定长数据
  • 【秋招笔试】2025.08.31饿了么秋招笔试题
  • SPMTE 2022概述
  • 线程池常见面试问答
  • 一次解决 Elasticsearch 两大难题: 掌握去重和深分页的最佳实践
  • Day19_【机器学习—线性回归 (1)】
  • PerfectSquares.java