卷积神经网络实战(3)
继续说上一次模型定义,额外说一下就是padding='same'可以让它自动计算需要几圈padding。'valid'就是padding = 0。当步长为2的时候,same无法使用。
class CNN(nn.Module):def __init__(self, activation="relu"):super(CNN, self).__init__()self.activation = F.relu if activation == "relu" else F.selu#输入通道数,图片是灰度图,所以是1,图片是彩色图,就是3,输出通道数,就是卷积核的个数(32,1,28,28)#输入x(32,1,28,28) 输出x(32,32,28,28)self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)#输入x(32,32,28,28) 输出x(32,32,28,28)self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(2, 2) #池化不能够改变通道数,池化核大小为2(2*2),步长为2 (28-2)//2+1=14self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)self.flatten = nn.Flatten()# input shape is (28, 28, 1) so the fc1 layer in_features is 128 * 3 * 3self.fc1 = nn.Linear(128 * 3 * 3, 128)self.fc2 = nn.Linear(128, 10) #输出尺寸(32,10)self.init_weights()def init_weights(self):"""使用 xavier 均匀分布来初始化全连接层、卷积层的权重 W"""for m in self.modules():if isinstance(m, (nn.Linear, nn.Conv2d)): nn.init.xavier_uniform_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):act = self.activationx = self.pool(act(self.conv2(act(self.conv1(x))))) # 1 * 28 * 28 -> 32 * 14 * 14# print(x.shape)x = self.pool(act(self.conv4(act(self.conv3(x))))) # 32 * 14 * 14 -> 64 * 7 * 7# print(x.shape)x = self.pool(act(self.conv6(act(self.conv5(x))))) # 64 * 7 * 7 -> 128 * 3 * 3# print(x.shape)x = self.flatten(x) # 128 * 3 * 3 ->1152x = act(self.fc1(x)) # 1152 -> 128x = self.fc2(x) # 128 -> 10return x
模型中init_weights的作用是初始化模型的可训练参数。对于初始化权重,使用了Xavier均匀分布初始化,据输入和输出的维度动态调整权重的范围,使得前向传播时输出的方差保持一致,反向传播时梯度的方差也保持一致。初始化偏置则设置为0。
-
PyTorch 默认会对
nn.Linear
和nn.Conv2d
使用 Kaiming均匀分布初始化(针对ReLU优化)。 -
显式调用
init_weights
是为了:-
统一所有层的初始化策略(例如本代码中强制使用 Xavier)。
-
适配特定激活函数(如代码支持
SELU
时,Xavier 更合适)。 -
显式控制初始化过程,避免框架默认行为的变化。
-
下一节放出全部代码。