PyTorch中nn.Module详解和综合代码示例
在 PyTorch 中,nn.Module
是神经网络中最核心的基类,用于构建所有模型。理解并熟练使用 nn.Module
是掌握 PyTorch 的关键。
一、什么是 nn.Module
nn.Module
是 PyTorch 中所有神经网络模块的基类。可以把它看作是“神经网络的容器”,它封装了以下几件事:
- 网络层(如 Linear、Conv2d 等)
- 前向传播逻辑(
forward
函数) - 模型参数(自动注册并可训练)
- 可嵌套(可以包含多个子模块)
- 便捷的模型保存 / 加载等工具函数
二、基础用法
2.1 自定义模型类
import torch
import torch.nn as nnclass MyNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x
2.2 实例化与调用
model = MyNet()
x = torch.randn(32, 784) # batch_size = 32
output = model(x) # 自动调用 forward
三、构造方法详解
3.1 __init__()
- 定义子模块、层等结构。
- 例如
self.conv1 = nn.Conv2d(...)
会被自动注册为模型参数。
3.2 forward()
- 定义前向传播逻辑。
- 不能手动调用,应使用
model(x)
形式。
四、常见模块层
模块名 | 作用 | 示例 |
---|---|---|
nn.Linear | 全连接层 | nn.Linear(128, 64) |
nn.Conv2d | 卷积层 | nn.Conv2d(3, 16, 3) |
nn.ReLU | 激活函数 | nn.ReLU() |
nn.Sigmoid | 激活函数 | nn.Sigmoid() |
nn.BatchNorm2d | 批归一化 | nn.BatchNorm2d(16) |
nn.Dropout | Dropout 层 | nn.Dropout(0.5) |
nn.LSTM | LSTM 层 | nn.LSTM(10, 20) |
nn.Sequential | 层的顺序容器 | 见下文说明 |
五、模型嵌套结构(子模块)
你可以将一个 nn.Module
作为另一个模块的子模块嵌套:
class Block(nn.Module):def __init__(self):super().__init__()self.layer = nn.Sequential(nn.Linear(64, 64),nn.ReLU())def forward(self, x):return self.layer(x)class Net(nn.Module):def __init__(self):super().__init__()self.block1 = Block()self.block2 = Block()self.output = nn.Linear(64, 10)def forward(self, x):x = self.block1(x)x = self.block2(x)return self.output(x)
六、内置方法和属性
方法 / 属性 | 说明 |
---|---|
model.parameters() | 返回所有可训练参数(用于优化器) |
model.named_parameters() | 返回带名字的参数迭代器 |
model.children() | 返回子模块迭代器 |
model.eval() | 设置为评估模式(Dropout、BN失效) |
model.train() | 设置为训练模式 |
model.to(device) | 将模型转移到 GPU/CPU |
model.state_dict() | 获取模型参数字典(保存) |
model.load_state_dict() | 加载模型参数字典 |
七、使用 nn.Sequential
nn.Sequential
是一个顺序容器,可以用来简化网络结构定义:
model = nn.Sequential(nn.Linear(784, 128),nn.ReLU(),nn.Linear(128, 10)
)
等价于手写的自定义 nn.Module
。适合前向传播是线性“流动”的结构。
八、实战完整示例:MNIST 分类网络
class MNISTNet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Flatten(),nn.Linear(28*28, 256),nn.ReLU(),nn.Linear(256, 10))def forward(self, x):return self.net(x)# 实例化模型
model = MNISTNet()
print(model)# 配置训练
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)# 示例训练循环
for epoch in range(10):for images, labels in train_loader:output = model(images)loss = criterion(output, labels)optimizer.zero_grad()loss.backward()optimizer.step()
九、常见陷阱和建议
问题 | 说明 |
---|---|
forward() 不起作用 | 应该使用 model(x) ,而不是手动调用 model.forward(x) |
忘记 super().__init__() | 子模块将不会被注册 |
参数未注册 | 层/模块必须赋值为 self.xxx = ... |
训练/测试模式混淆 | 注意 model.eval() 和 model.train() |
十、总结
项目 | 说明 |
---|---|
__init__() | 定义模型结构(子模块、层) |
forward() | 定义前向传播 |
自动注册参数 | 所有 self.xxx = nn.XXX(...) 都会被追踪 |
嵌套模块 | 支持递归子模块调用 |
便捷方法 | .parameters() 、.to() 、.eval() 等 |
十一、综合示例
以下是基于 PyTorch nn.Module
封装的三种经典深度学习架构(ResNet18、UNet、Transformer)的简洁而完整的实现,适合初学者快速上手。
1、ResNet18 简洁实现(适合图像分类)
import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_planes, planes, stride=1, downsample=None):super().__init__()self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampledef forward(self, x):identity = xif self.downsample:identity = self.downsample(x)out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += identityreturn F.relu(out)class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000):super().__init__()self.in_planes = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, planes, blocks, stride=1):downsample = Noneif stride != 1 or self.in_planes != planes * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_planes, planes * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(planes * block.expansion))layers = [block(self.in_planes, planes, stride, downsample)]self.in_planes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.in_planes, planes))return nn.Sequential(*layers)def forward(self, x):x = self.pool(F.relu(self.bn1(self.conv1(x))))x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x).flatten(1)return self.fc(x)def ResNet18(num_classes=1000):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
2、UNet(适合图像分割)
class UNetBlock(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.block = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_ch, out_ch, 3, padding=1),nn.ReLU(inplace=True))def forward(self, x):return self.block(x)class UNet(nn.Module):def __init__(self, in_channels=1, out_channels=1):super().__init__()self.enc1 = UNetBlock(in_channels, 64)self.enc2 = UNetBlock(64, 128)self.enc3 = UNetBlock(128, 256)self.enc4 = UNetBlock(256, 512)self.pool = nn.MaxPool2d(2)self.bottleneck = UNetBlock(512, 1024)self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)self.dec4 = UNetBlock(1024, 512)self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)self.dec3 = UNetBlock(512, 256)self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)self.dec2 = UNetBlock(256, 128)self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)self.dec1 = UNetBlock(128, 64)self.final = nn.Conv2d(64, out_channels, kernel_size=1)def forward(self, x):e1 = self.enc1(x)e2 = self.enc2(self.pool(e1))e3 = self.enc3(self.pool(e2))e4 = self.enc4(self.pool(e3))b = self.bottleneck(self.pool(e4))d4 = self.upconv4(b)d4 = self.dec4(torch.cat([d4, e4], dim=1))d3 = self.upconv3(d4)d3 = self.dec3(torch.cat([d3, e3], dim=1))d2 = self.upconv2(d3)d2 = self.dec2(torch.cat([d2, e2], dim=1))d1 = self.upconv1(d2)d1 = self.dec1(torch.cat([d1, e1], dim=1))return self.final(d1)
3、简化版 Transformer 编码器(适合序列建模)
class TransformerBlock(nn.Module):def __init__(self, embed_dim, heads, ff_hidden_dim, dropout=0.1):super().__init__()self.attn = nn.MultiheadAttention(embed_dim, heads, dropout=dropout, batch_first=True)self.ff = nn.Sequential(nn.Linear(embed_dim, ff_hidden_dim),nn.ReLU(),nn.Linear(ff_hidden_dim, embed_dim))self.norm1 = nn.LayerNorm(embed_dim)self.norm2 = nn.LayerNorm(embed_dim)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):attn_out, _ = self.attn(x, x, x, attn_mask=mask)x = self.norm1(x + self.dropout(attn_out))ff_out = self.ff(x)x = self.norm2(x + self.dropout(ff_out))return xclass TransformerEncoder(nn.Module):def __init__(self, vocab_size, embed_dim=512, n_heads=8, ff_dim=2048, num_layers=6, max_len=512):super().__init__()self.embedding = nn.Embedding(vocab_size, embed_dim)self.pos_encoding = self._generate_positional_encoding(max_len, embed_dim)self.layers = nn.ModuleList([TransformerBlock(embed_dim, n_heads, ff_dim)for _ in range(num_layers)])self.dropout = nn.Dropout(0.1)def _generate_positional_encoding(self, max_len, d_model):pos = torch.arange(0, max_len).unsqueeze(1)i = torch.arange(0, d_model, 2)angle_rates = 1 / torch.pow(10000, (i / d_model))pos_enc = torch.zeros(max_len, d_model)pos_enc[:, 0::2] = torch.sin(pos * angle_rates)pos_enc[:, 1::2] = torch.cos(pos * angle_rates)return pos_enc.unsqueeze(0)def forward(self, x):B, T = x.shapex = self.embedding(x) + self.pos_encoding[:, :T].to(x.device)x = self.dropout(x)for layer in self.layers:x = layer(x)return x
4、 总结对比
模型类型 | 场景 | 特点 |
---|---|---|
ResNet18 | 图像分类 | 深残差网络结构,适合迁移学习 |
UNet | 图像分割 | 对称结构,编码 + 解码 + skip |
Transformer | NLP / 序列建模 | 全注意力机制,无卷积无循环 |