当前位置: 首页 > java >正文

卷积神经网络实战(3)

继续说上一次模型定义,额外说一下就是padding='same'可以让它自动计算需要几圈padding。'valid'就是padding = 0。当步长为2的时候,same无法使用。

class CNN(nn.Module):def __init__(self, activation="relu"):super(CNN, self).__init__()self.activation = F.relu if activation == "relu" else F.selu#输入通道数,图片是灰度图,所以是1,图片是彩色图,就是3,输出通道数,就是卷积核的个数(32,1,28,28)#输入x(32,1,28,28) 输出x(32,32,28,28)self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)#输入x(32,32,28,28) 输出x(32,32,28,28)self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(2, 2) #池化不能够改变通道数,池化核大小为2(2*2),步长为2  (28-2)//2+1=14self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)self.flatten = nn.Flatten()# input shape is (28, 28, 1) so the fc1 layer in_features is 128 * 3 * 3self.fc1 = nn.Linear(128 * 3 * 3, 128)self.fc2 = nn.Linear(128, 10) #输出尺寸(32,10)self.init_weights()def init_weights(self):"""使用 xavier 均匀分布来初始化全连接层、卷积层的权重 W"""for m in self.modules():if isinstance(m, (nn.Linear, nn.Conv2d)): nn.init.xavier_uniform_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):act = self.activationx = self.pool(act(self.conv2(act(self.conv1(x))))) # 1 * 28 * 28 -> 32 * 14 * 14# print(x.shape)x = self.pool(act(self.conv4(act(self.conv3(x))))) # 32 * 14 * 14 -> 64 * 7 * 7# print(x.shape)x = self.pool(act(self.conv6(act(self.conv5(x))))) # 64 * 7 * 7 -> 128 * 3 * 3# print(x.shape)x = self.flatten(x) # 128 * 3 * 3 ->1152x = act(self.fc1(x)) # 1152 -> 128x = self.fc2(x) # 128 -> 10return x

模型中init_weights的作用是初始化模型的可训练参数。对于初始化权重,使用了Xavier均匀分布初始化,据输入和输出的维度动态调整权重的范围,使得前向传播时输出的方差保持一致,反向传播时梯度的方差也保持一致。初始化偏置则设置为0。

  • PyTorch 默认会对 nn.Linear 和 nn.Conv2d 使用 Kaiming均匀分布初始化(针对ReLU优化)。

  • 显式调用 init_weights 是为了:

    1. 统一所有层的初始化策略(例如本代码中强制使用 Xavier)。

    2. 适配特定激活函数(如代码支持 SELU 时,Xavier 更合适)。

    3. 显式控制初始化过程,避免框架默认行为的变化。

  下一节放出全部代码。

http://www.xdnf.cn/news/4346.html

相关文章:

  • 【基础】Python包管理工具uv使用全教程
  • Java日期格式化方法总结
  • DApp 开发:开启去中心化应用新时代
  • Spring事务和事务传播机制
  • C语言| 递归和循环的优缺点
  • 塔能水泵节能方案:精准驱动工厂能耗优化
  • 展锐Android13禁止用户使用超级省电
  • 新一代Python专业编译器Nuitka简介
  • ROS2:自定义接口文件(无废话)
  • 多模态理论知识
  • 二叉树与堆排序(概念|遍历|实现)
  • python酒店在线预定管理系统-酒店客房管理系统-快捷酒店入住系统
  • 【Linux系统】vim编辑器的使用
  • FoMo 数据集是一个专注于机器人在季节性积雪变化环境中的导航数据集,记录了不同季节(无雪、浅雪、深雪)下的传感器数据和轨迹信息。
  • C语言编程--递归程序--求数组的最大元素值
  • 油气地震资料信号处理中的NMO(正常时差校正)
  • 【网络篇】传输层TCP协议的确认应答,超时重传机制
  • IT咨询——企业数据资产怎样评估
  • 满分PPT | 基于数据运营的新型智慧城市实践与思考智慧城市数据中台解决方案智能建筑与智慧城市建设方案
  • 基于nacos实现动态线程池设计与实践:告别固定配置,拥抱弹性调度
  • LabVIEW与 IMAQ Vision 机器视觉应用
  • C++类与对象基础续
  • 15.命令模式:思考与解读
  • 毫米波雷达原理(最通俗的解释)
  • MATLAB程序演示与编程思路,相对导航,四个小车的形式,使用集中式扩展卡尔曼滤波(fullyCN-EKF)
  • go 编译报错:build constraints exclude all Go files
  • Python使用爬虫ip抓取热点新闻
  • autojspro怎么免费用
  • 【原创分享】魔音变声器内含超多语音包实时变声
  • C#中从本地(两个路径文件夹)中实时拿图显示到窗口中并接收(两个tcp发送的信号)转为字符串显示在窗体中实现检测可视化