
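Residual Network (ResNet)

Introduction to Residual Networks (ResNet)

Residual Network (ResNet) is a deep convolutional neural network architecture proposed in 2015 by Kaiming He and colleagues at Microsoft Research. Its core idea is the residual (skip) connection. The connection makes very deep networks much easier to optimize: it counters the degradation problem and gives gradients a direct path back through the network. As a result, networks with hundreds or even more than a thousand layers can be trained without the training error rising as depth grows. ResNet performs strongly in image classification (e.g. ImageNet), object detection (e.g. as the backbone of Faster R-CNN) and semantic segmentation, and it has become one of the foundational architectures of modern deep learning.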

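Core idea

Problems with conventional deep networks

  • As depth increases, accuracy saturates and then degrades. This is not overfitting: the training error itself goes up.

  • Vanishing/exploding gradients: during backpropagation, the gradient can become extremely small or extremely large after passing through many layers, which makes the parameters hard to update. Normalized initialization and batch normalization largely tame this, but the degradation above remains even with them.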
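To make the vanishing-gradient bullet concrete, here is a toy scalar illustration; it is not from the original post. It assumes a hypothetical 1-D chain of sigmoid layers with weight 1. By the chain rule, the end-to-end gradient is a product of per-layer factors. Each factor is at most 0.25 because the sigmoid's derivative never exceeds 0.25, so the product shrinks geometrically with depth. Factors larger than 1 would instead make it explode.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def plain_chain_gradient(depth, w=1.0, x=0.5):
    """d(output)/d(input) for a hypothetical 1-D chain x_{l+1} = sigmoid(w * x_l).

    The chain rule multiplies one local factor w * sigmoid'(w * x_l) per layer;
    with w = 1 every factor is below 0.25, so the product decays geometrically.
    """
    grad = 1.0
    for _ in range(depth):
        s = sigmoid(w * x)
        grad *= w * s * (1.0 - s)  # local derivative of this layer
        x = s                      # forward value fed to the next layer
    return grad


for depth in (1, 5, 10, 20):
    print(depth, plain_chain_gradient(depth))
```

At 20 layers the gradient is already below 0.25**20 (about 1e-12), so the earliest layers receive almost no update signal.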

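The residual-connection solution

A conventional network learns the target mapping $H(x)$ directly. ResNet instead learns the residual mapping $F(x) = H(x) - x$ and adds the input back through a skip connection:

$$H(x) = F(x) + x$$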

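  • If the residual $F(x) = 0$, then $H(x) = x$, an identity mapping. Driving a residual branch toward zero is easy for the optimizer. A deeper network can therefore fall back on reproducing a shallower one, so in principle its training error should be no worse than the shallower network's.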
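The identity argument above can be checked with a minimal pure-Python sketch. The two-layer branch F(x) = W2 · relu(W1 · x) and the helper names are hypothetical illustrations, and the sketch omits the BatchNorm and the final ReLU that the post's Residual block uses. With all weights zero, the residual block returns its input unchanged. The same two layers without the shortcut return zeros instead, so to act as the identity they would have to learn weights that reproduce x through the nonlinearity.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def matvec(W, v):
    # Matrix-vector product for a list-of-rows matrix
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in W]


def residual_block(x, W1, W2):
    """H(x) = F(x) + x with a hypothetical branch F(x) = W2 · relu(W1 · x)."""
    f = matvec(W2, relu(matvec(W1, x)))
    return [fi + xi for fi, xi in zip(f, x)]


def plain_block(x, W1, W2):
    """The same two layers without the shortcut: H(x) = W2 · relu(W1 · x)."""
    return matvec(W2, relu(matvec(W1, x)))


x = [1.0, -2.0, 3.0]
zeros = [[0.0] * 3 for _ in range(3)]
print(residual_block(x, zeros, zeros))  # [1.0, -2.0, 3.0]  (identity)
print(plain_block(x, zeros, zeros))     # [0.0, 0.0, 0.0]
```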

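Code

The PyTorch code below builds an 18-layer ResNet (ResNet-18) and trains it on the Fashion-MNIST dataset.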

import sys
import time

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader


# Global average pooling layer
class GlobalAvgPool2d(nn.Module):
    # Global average pooling is average pooling whose window equals
    # the input's height and width
    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])


# Residual block
class Residual(nn.Module):
    # in_channels/out_channels: input and output channel counts
    # use_1x1conv: project the shortcut with a 1x1 convolution
    # stride: stride of the first 3x3 convolution (and of the 1x1 shortcut)
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super().__init__()
        # A 3x3 kernel with padding=1 keeps the feature-map size when stride=1
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)


def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block:
        # The first stage keeps the channel count of its input
        assert in_channels == out_channels
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # Halve height/width, change channels, project the shortcut to match
            blk.append(Residual(in_channels, out_channels,
                                use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)


def resnet18(output=10, in_channels=3):
    """Build an 18-layer residual network.

    Parameters:
        output - number of output classes
        in_channels - number of input channels
    Returns:
        the residual network (nn.Sequential)
    """
    net = nn.Sequential(
        nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("resnet_block4", resnet_block(256, 512, 2))
    # Output of GlobalAvgPool2d: (batch, 512, 1, 1)
    net.add_module("global_avg_pool", GlobalAvgPool2d())
    net.add_module("fc", nn.Sequential(nn.Flatten(), nn.Linear(512, output)))
    return net


def load_data_fashion_mnist(batch_size, resize=None):
    """Load Fashion-MNIST and split it into mini-batches.

    Parameters:
        batch_size - mini-batch size (int)
        resize - optional target size to resize each image to
    Returns:
        train_iter - mini-batch iterator over the training set
        test_iter - mini-batch iterator over the test set
    """
    # Image preprocessing pipeline
    trans = []
    if resize:
        trans.append(transforms.Resize(size=resize))
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root='data/FashionMNIST', train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(
        root='data/FashionMNIST', train=False, download=True, transform=transform)
    # 0 means no extra worker processes are used to load data
    num_workers = 0 if sys.platform.startswith('win') else 4
    train_iter = DataLoader(mnist_train, batch_size=batch_size,
                            shuffle=True, num_workers=num_workers)
    test_iter = DataLoader(mnist_test, batch_size=batch_size,
                           shuffle=False, num_workers=num_workers)
    return train_iter, test_iter


def evaluate_accuracy(data_iter, net, device=None):
    """Compute the accuracy of a multi-class classifier.

    Parameters:
        data_iter - mini-batch iterator over the samples
        net - the network
        device - device (GPU or CPU) to run the computation on
    Returns:
        accuracy (float)
    """
    if device is None and isinstance(net, nn.Module):
        # If no device is given, use the device of the net's parameters
        device = next(net.parameters()).device
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(net, nn.Module):
                # Evaluation mode: disables dropout, BatchNorm uses running statistics
                net.eval()
                # .cpu() so the result can be accumulated as a Python number
                acc_sum += (net(X.to(device)).argmax(dim=1) ==
                            y.to(device)).float().sum().cpu().item()
                # Switch back to training mode
                net.train()
            else:
                # Hand-written model function (not used after Section 2.13; no GPU support)
                if 'is_training' in net.__code__.co_varnames:
                    # Call with is_training=False
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n


def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    """Train net and report the test accuracy after every epoch.

    Parameters:
        net - the network
        train_iter - mini-batch iterator over the training set
        test_iter - mini-batch iterator over the test set
        batch_size - mini-batch size (unused here; kept for interface compatibility)
        optimizer - the optimizer
        device - device (GPU or CPU) to run the computation on
        num_epochs - number of epochs
    """
    # Move the model to the target device
    net = net.to(device)
    print("training on", device)
    loss = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, batch_count = 0.0, 0.0, 0, 0
        start = time.time()
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()  # clear the gradients from the previous step
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net, device=device)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n,
                 test_acc, time.time() - start))


# The main guard is needed because DataLoader worker processes (num_workers > 0)
# re-import this script on platforms that use the "spawn" start method
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Fashion-MNIST images are single-channel, hence in_channels=1
    net = resnet18(output=10, in_channels=1)
    batch_size = 256
    train_iter, test_iter = load_data_fashion_mnist(batch_size)
    lr, num_epochs = 0.001, 2
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
    torch.save(net.state_dict(), 'ResNet18.params')
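A quick way to see why resnet18() works on 28×28 Fashion-MNIST images is to trace the feature-map size through each stage. The sketch below uses a hypothetical helper, not part of the original code, that applies the standard output-size formula floor((n + 2p − k) / s) + 1. It assumes PyTorch defaults for nn.Conv2d and nn.MaxPool2d (dilation=1, ceil_mode=False). It checks only the first block of each stage: the remaining blocks use stride 1 with padding 1 and keep the size. The check also confirms that the main branch and the shortcut agree in size before the addition.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of a conv or pooling layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_spatial_sizes(size=28):
    """Trace height/width through the resnet18() layout shown above."""
    sizes = []
    size = conv_out(size, kernel=7, stride=2, padding=3)  # 7x7 stem convolution
    sizes.append(("stem conv", size))
    size = conv_out(size, kernel=3, stride=2, padding=1)  # 3x3 max pooling
    sizes.append(("max pool", size))
    for name, stride in (("resnet_block1", 1), ("resnet_block2", 2),
                         ("resnet_block3", 2), ("resnet_block4", 2)):
        main = conv_out(size, kernel=3, stride=stride, padding=1)  # first 3x3 conv
        shortcut = conv_out(size, kernel=1, stride=stride)         # 1x1 (or identity)
        assert main == shortcut, "branches must match before the addition"
        size = main
        sizes.append((name, size))
    sizes.append(("global_avg_pool", 1))  # pools whatever remains down to 1x1
    return sizes


for name, size in resnet18_spatial_sizes(28):
    print(f"{name:16s} {size}x{size}")
```

For a 28×28 input the sizes are 14, 7, 7, 4, 2, 1 and then 1 after global pooling. The 7×7 stem and the early downsampling (both taken from the ImageNet design) leave only a 1×1 map in the last stage for such small images.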

http://www.xdnf.cn/news/411751.html
